<a href="https://colab.research.google.com/github/dantheman625/nlp_doc_info_extraction/blob/final_touch/5_Complete_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
!pip install seqeval scikit-learn datasets wandb nltk

In [None]:
import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    LongformerTokenizerFast,
    pipeline
)
from datasets import Dataset
import numpy as np
import os
import json
from seqeval.metrics import precision_score as ner_prec, recall_score as ner_rec, f1_score as ner_f1
from sklearn.metrics import classification_report, precision_recall_fscore_support

In [None]:
# Default (local) project root; reassigned further down once Google Drive is mounted.
base_path = "project_files"

## Wandb login

In [None]:
# Authenticate with Weights & Biases so the metrics computed below can be logged.
import wandb
wandb.login()

# Datasets

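Load the challenge evaluation data. Note: the cell below reads every JSON file under `data/raw/dev`; the path to the processed `Final_eval.json` is defined but not used for loading.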


## Mount Drive

In [None]:
# Mount Google Drive (Colab only) so the project folder under MyDrive can be
# read from and written to at /content/drive.
from google.colab import drive
import os
import json

drive.mount('/content/drive')

## Set Project folder in Google Drive

In [None]:
# Project root inside the mounted Drive; data and checkpoint paths below are built from it.
drive_root = "/content/drive/MyDrive"
base_path = f"{drive_root}/project_files"


## Load file

In [None]:
# Path to the processed evaluation file. The relative part must not start with
# "/": os.path.join discards every component that precedes an absolute one.
eval_path = os.path.join(base_path, 'data', 'processed', 'Final_eval.json')

# The examples evaluated below are read from the raw dev split: every *.json file
# under this folder (including sub-folders) is expected to hold a list of examples.
folder_path = f'{base_path}/data/raw/dev'
print(folder_path)

eval_data = []
for root, dirs, files in os.walk(folder_path):
    dirs.sort()  # visit sub-folders in a deterministic order
    for file_name in sorted(files):
        if not file_name.endswith('.json'):
            continue
        # Join with `root` (not `folder_path`) so files inside sub-folders resolve.
        with open(os.path.join(root, file_name), "r") as f:
            eval_data.extend(json.load(f))

if not eval_data:
    raise FileNotFoundError(f"No evaluation examples found under {folder_path}")

dataset = Dataset.from_list(eval_data)
print("Sample example:")
print(dataset[0])

In [None]:
# BIO tag inventory for NER, taken from the first example's entity label set:
# 'O' first, then one B- tag per entity type, then one I- tag per entity type.
entity_labels = dataset[0]['entity_label_set']
begin_tags = [f"B-{l}" for l in entity_labels]
inside_tags = [f"I-{l}" for l in entity_labels]
label_list = ['O', *begin_tags, *inside_tags]
label2id = dict(zip(label_list, range(len(label_list))))
id2label = {idx: name for name, idx in label2id.items()}

# Define models

# Baseline models
Define the pretrained (not fine-tuned) model used as the baseline for each task.

In [None]:
# Hugging Face Hub ids of the untuned checkpoints used as baselines.
baseline_ner_name, baseline_re_name = (
    "allenai/longformer-base-4096",
    "SpanBERT/spanbert-large-cased",
)


# Trained models

Define the path to the fine-tuned checkpoint for each task.

In [None]:
# Fine-tuned checkpoints stored under the project folder.
checkpoint_root = f"{base_path}/checkpoints"
trained_ner_name = f"{checkpoint_root}/NER/longformer_tuned"
trained_re_name = f"{checkpoint_root}/RE/checkpoint-306"

# Model selection

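Choose which models to use for NER and RE: baseline/baseline, trained/trained, trained/baseline, or baseline/trained. Run only the one cell below that matches the combination you want. Each cell overwrites `ner_model_name` and `re_model_name`, so running all of them leaves the last combination (NER baseline, RE trained) in effect.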

## Both baseline

In [None]:
# Combination: baseline NER + baseline RE.
ner_model_name, re_model_name = baseline_ner_name, baseline_re_name

## Both trained

In [None]:
# Combination: trained NER + trained RE.
ner_model_name, re_model_name = trained_ner_name, trained_re_name

## NER: trained, RE: baseline

In [None]:
# Combination: trained NER + baseline RE.
ner_model_name, re_model_name = trained_ner_name, baseline_re_name

## NER: baseline, RE: trained

In [None]:
# Combination: baseline NER + trained RE.
ner_model_name, re_model_name = baseline_ner_name, trained_re_name

# Load Models and Tokenizer

## NER

In [None]:
# The NER tokenizer always comes from the baseline checkpoint.
# NOTE(review): assumes the fine-tuned model was trained with this same tokenizer — confirm.
ner_tokenizer = LongformerTokenizerFast.from_pretrained(baseline_ner_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)

# Use the GPU when one is present (same rule as the RE pipeline); -1 means CPU.
ner_device = 0 if torch.cuda.is_available() else -1

# aggregation_strategy='simple' yields span-level entities; the evaluation cells
# below read their 'entity_group', 'word', 'start' and 'end' fields.
ner_pipe = pipeline(
    'ner',
    model=ner_model,
    tokenizer=ner_tokenizer,
    device=ner_device,
    aggregation_strategy='simple',
)


## RE

In [None]:
re_tokenizer = AutoTokenizer.from_pretrained(re_model_name)

if re_model_name != baseline_re_name:
    # A fine-tuned checkpoint is loaded with its saved config (head size and labels).
    re_model = AutoModelForSequenceClassification.from_pretrained(re_model_name)
    print(f"Loaded trained RE model '{re_model_name}' with head size num_labels={re_model.config.num_labels}")
else:
    # The baseline gets its classification head sized through a config override.
    # NOTE(review): label2id/id2label here are the NER BIO tags built from
    # 'entity_label_set', not relation labels — confirm this is intended.
    cfg = AutoConfig.from_pretrained(
        re_model_name,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
    )
    re_model = AutoModelForSequenceClassification.from_pretrained(re_model_name, config=cfg)
    print(f"Loaded baseline RE model '{re_model_name}' with overridden head size num_labels={re_model.config.num_labels}")


# Initialize Wandb

In [None]:
# The run name combines the last path component of both model identifiers.
ner_tag = ner_model_name.split('/')[-1]
re_tag = re_model_name.split('/')[-1]

# NOTE(review): "seed" is only recorded here; no cell in view seeds any RNG, and
# the data loaded above comes from data/raw/dev rather than Final_eval.json.
run_config = dict(
    ner_model=ner_model_name,
    re_model=re_model_name,
    dataset="Final_eval.json",
    batch_size=32,
    max_length=256,
    seed=42,
)

wandb.init(project="model-eval", name=f"eval_{ner_tag}_{re_tag}", config=run_config)

# NER Eval

Output: one record per document with its predicted entities, saved below as `ner_val_results.json`.

In [None]:
# Run the NER pipeline over each full document and keep the predictions together
# with the metadata needed to match them back to the gold annotations.
ner_val_results = [
    {
        'domain': example.get('domain'),
        'doc_title': example.get('title', f'doc_{idx}'),
        'entities': ner_pipe(example['doc']),
        'doc': example.get('doc'),
    }
    for idx, example in enumerate(eval_data)
]

print(ner_val_results[0])

Save the NER output to disk so Nina can check it.

In [None]:
def _json_default(obj):
    """Convert numpy values (e.g. numpy scalar scores) into plain Python for json.dump.

    Anything else raises TypeError, which is what json.dump expects from `default`.
    Returning the object unchanged (as before) surfaces as a misleading
    "Circular reference detected" ValueError instead.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Persist the raw NER output so it can be inspected manually.
with open(f'{base_path}/data/processed/ner_val_results.json', 'w') as f:
    json.dump(ner_val_results, f, default=_json_default)


In [None]:
def _char_to_token_map(text, tokens):
    """Map each character offset covered by a whitespace token to that token's index."""
    char2tok = {}
    offset = 0
    for i, tok in enumerate(tokens):
        start = text.find(tok, offset)
        end = start + len(tok)
        for c in range(start, end):
            char2tok[c] = i
        offset = end
    return char2tok


def _tag_span(labels, char2tok, start, end, ent_type):
    """Write BIO tags for the character span [start, end) into `labels` in place.

    Spans whose first or last character is not inside a token are ignored.
    """
    t0 = char2tok.get(start)
    t1 = char2tok.get(end - 1)
    if t0 is None or t1 is None:
        return
    labels[t0] = f'B-{ent_type}'
    for t in range(t0 + 1, t1 + 1):
        labels[t] = f'I-{ent_type}'


def _strip_span(text, start, end):
    """Shrink [start, end) so it neither begins nor ends on whitespace."""
    while start < end and text[start].isspace():
        start += 1
    while end > start and text[end - 1].isspace():
        end -= 1
    return start, end


# Key both sides the same way the NER cell does (title falls back to doc_{idx}),
# so documents without a title still line up instead of raising KeyError.
gt_index = {
    (ex.get('domain'), ex.get('title', f'doc_{idx}')): ex
    for idx, ex in enumerate(eval_data)
}
pred_index = {(p['domain'], p['doc_title']): p for p in ner_val_results}

true_ner_labels = []
pred_ner_labels = []

for key, gt in gt_index.items():
    pred = pred_index.get(key)
    if pred is None:
        continue

    # Evaluate at whitespace-token level: one BIO tag per token of the document.
    text = gt['doc']
    tokens = text.split()
    char2tok = _char_to_token_map(text, tokens)

    true_labels = ['O'] * len(tokens)
    pred_labels = ['O'] * len(tokens)

    # Gold: the annotations provide mention strings, not offsets, so every
    # occurrence of each mention in the document is tagged.
    for ent in gt['entities']:
        ent_type = ent['type']
        for mention in ent['mentions']:
            if not mention:
                # find('') matches at every position without advancing -> infinite loop.
                continue
            start = text.find(mention)
            while start != -1:
                end = start + len(mention)
                _tag_span(true_labels, char2tok, start, end, ent_type)
                start = text.find(mention, end)

    # Predictions: use the pipeline's character offsets, trimmed of surrounding
    # whitespace so a span starting on a space is not silently dropped.
    for ent in pred['entities']:
        start, end = _strip_span(text, ent['start'], ent['end'])
        if start < end:
            _tag_span(pred_labels, char2tok, start, end, ent['entity_group'])

    true_ner_labels.append(true_labels)
    pred_ner_labels.append(pred_labels)


## Log Metrics in Wandb

In [None]:
# Entity-level (seqeval) precision / recall / F1 over the aligned BIO sequences.
prec_ner, rec_ner, f1_ner = (
    metric(true_ner_labels, pred_ner_labels) for metric in (ner_prec, ner_rec, ner_f1)
)

for value in (prec_ner, rec_ner, f1_ner):
    print(value)

wandb.log({
    "ner/precision": prec_ner,
    "ner/recall": rec_ner,
    "ner/f1": f1_ner,
})


# RE Eval

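Input: the NER entity file and the original challenge documents. Each document is split into sentences, and entities are matched to the sentence in which they start. Every pair of entities in the same sentence becomes a candidate, and the result is a list of dicts.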

## Preprocessing

In [None]:
# Translates the challenge's relation names into DocRED-style relation names,
# i.e. the label space the gold relations are compared in against the RE model's
# predictions. Challenge relations absent from this dict get no gold label.
# Several challenge relations share a target (e.g. "Affiliation" and "MemberOf"
# both -> "member of"). Entries mapped to themselves (e.g. "ApprovedBy") keep the
# challenge name — NOTE(review): confirm the RE model can actually emit those.
mapping_challenge_to_docred = {
    "Affiliation":                         "member of",
    "ApprovedBy":                          "ApprovedBy",
    "Author":                              "author",
    "AwardReceived":                       "award received",
    "BasedOn":                             "BasedOn",
    "Capital":                             "capital",
    "Causes":                              "Causes",
    "Continent":                           "continent",
    "ContributedToCreativeWork":           "ContributedToCreativeWork",
    "Country":                             "country",
    "CountryOfCitizenship":                "country of citizenship",
    "Creator":                             "creator",
    "Developer":                           "developer",
    "DifferentFrom":                       "DifferentFrom",
    "Director":                            "director",
    "EducatedAt":                          "educated at",
    "Employer":                            "employer",
    "FieldOfWork":                         "FieldOfWork",
    "FollowedBy":                          "followed by",
    "Follows":                             "follows",
    "Founded":                             "founded",
    "FoundedBy":                           "founded by",
    "HasCause":                            "HasCause",
    "HasEffect":                           "HasEffect",
    "HasPart":                             "HasPart",
    "HasWorksInTheCollection":             "HasWorksInTheCollection",
    "InfluencedBy":                        "influenced by",
    "IssuedBy":                            "IssuedBy",
    "LocatedIn":                           "located in the administrative territorial entity",
    "Location":                            "location",
    "MemberOf":                            "member of",
    "NamedBy":                             "NamedBy",
    "NominatedFor":                        "nominated for",
    "OfficialLanguage":                    "official language",
    "OwnedBy":                             "owned by",
    "OwnerOf":                             "owner of",
    "ParentOrganization":                  "parent organization",
    "PartOf":                              "part of",
    "Partner":                             "partner",
    "PlaceOfBirth":                        "place of birth",
    "PositionHeld":                        "position held",
    "PublishedIn":                         "PublishedIn",
    "Replaces":                            "replaces",
    "SaidToBeTheSameAs":                   "SaidToBeTheSameAs",
    "Studies":                             "Studies",
    "UsedBy":                              "UsedBy",
    "Uses":                                "Uses",
    "WorkLocation":                        "work location",

    # NOTE(review): the entries below look like approximate rather than exact
    # equivalents (e.g. "PrimeFactor" -> "part of") — confirm each one.
    "LanguageOfWorkOrName":                "original language of work",
    "LanguageUsed":                        "languages spoken, written or signed",
    "OriginalLanguageOfFilmOrTvShow":      "original language of work",
    "PartyChiefRepresentative":            "head of government",
    "PrimeFactor":                         "part of",
    "TwinnedAdministrativeBody":           "sister city",

    # NOTE(review): looser approximations still (e.g. "DiplomaticRelation" ->
    # "conflict", "AcademicDegree" -> "educated at") — verify before trusting
    # per-label scores for these relations.
    "AcademicDegree":                      "educated at",
    "AdjacentStation":                     "shares border with",
    "AppliesToPeople":                     "applies to jurisdiction",
    "CitesWork":                           "present in work",
    "ContainsAdministrativeTerritorialEntity":     "contains administrative territorial entity",
    "ContainsTheAdministrativeTerritorialEntity":  "contains administrative territorial entity",
    "DiplomaticRelation":                  "conflict",
    "HasQuality":                          "genre",
    "InOppositionTo":                      "separated from",
    "InspiredBy":                          "BasedOn",
    "InterestedIn":                        "Studies",
    "NamedAfter":                          "NamedBy",
    "NativeLanguage":                      "languages spoken, written or signed",
    "OperatingSystem":                     "platform",
    "PhysicallyInteractsWith":             "shares border with",
    "PracticedBy":                         "UsedBy",
    "PresentedIn":                         "present in work",
    "Promoted":                            "HasEffect",
    "RegulatedBy":                         "IssuedBy",
    "SharesBorderWith":                    "shares border with",
    "SignificantEvent":                    "location",
}



In [None]:
import nltk
from nltk.tokenize import sent_tokenize
from collections import defaultdict

nltk.download("punkt")
nltk.download('punkt_tab')


def _sentence_spans(text):
    """Return (sentence, start, end) for each NLTK sentence, with document char offsets.

    A sentence that cannot be found verbatim after the previous one is skipped,
    instead of yielding start=-1 and corrupting the cursor for all later sentences.
    """
    spans = []
    cursor = 0
    for sent in sent_tokenize(text):
        start = text.find(sent, cursor)
        if start == -1:
            continue
        end = start + len(sent)
        spans.append((sent, start, end))
        cursor = end
    return spans


def _relative_span(sent, s_start, start, end):
    """Convert a document-level [start, end) span into sentence offsets,
    clipped to the sentence and trimmed of surrounding whitespace."""
    a = max(start - s_start, 0)
    b = min(end - s_start, len(sent))
    while a < b and sent[a].isspace():
        a += 1
    while b > a and sent[b - 1].isspace():
        b -= 1
    return a, b


def _mark_pair(sent, s_start, span1, span2):
    """Wrap two mentions in `sent` with [E1]...[/E1] and [E2]...[/E2] markers.

    Markers are placed at the mentions' character offsets rather than at the first
    textual match, so repeated mentions, a mention contained in the other one, or
    text that also occurs earlier in the sentence are marked at the right place.
    Returns None when either span is empty or the two spans overlap.
    """
    a1, b1 = _relative_span(sent, s_start, *span1)
    a2, b2 = _relative_span(sent, s_start, *span2)
    if a1 >= b1 or a2 >= b2 or (a1 < b2 and a2 < b1):
        return None
    marked = sent
    # Insert the right-most mention first so the left one's offsets remain valid.
    for a, b, tag in sorted([(a1, b1, "E1"), (a2, b2, "E2")], reverse=True):
        marked = f"{marked[:a]}[{tag}]{marked[a:b]}[/{tag}]{marked[b:]}"
    return marked


# Every pair of predicted entities that start in the same sentence becomes one
# RE candidate; all candidates are labelled "no_relation" at this stage.
candidates = []
for doc in ner_val_results:
    text = doc["doc"]
    title = doc["doc_title"].strip().lower()

    mentions = [
        (ent["word"].strip(), ent["start"], ent["end"], ent["entity_group"])
        for ent in doc["entities"]
    ]

    for sent, s_start, s_end in _sentence_spans(text):
        # A mention belongs to the sentence its first character falls in.
        sent_mentions = [m for m in mentions if s_start <= m[1] < s_end]

        for i, (w1, a1, b1, _label1) in enumerate(sent_mentions):
            for w2, a2, b2, _label2 in sent_mentions[i + 1:]:
                snippet = _mark_pair(sent, s_start, (a1, b1), (a2, b2))
                if snippet is None:
                    continue
                candidates.append({
                    "doc_title":       title,
                    "text":            snippet,
                    "entity1_label":   w1,
                    "entity2_label":   w2,
                    "relation_label":  "no_relation"
                })

from datasets import Dataset
ds = Dataset.from_list(candidates)

print(f"Built {len(candidates)} candidate pairs")
print(f"HF Dataset contains {len(ds)} rows")

re_val_ds = ds


## Trainer option

In [None]:
import json, os
from transformers import Trainer, TrainingArguments

# Reload the RE tokenizer/model so this option can run on its own.
re_tokenizer = AutoTokenizer.from_pretrained(re_model_name)
if re_model_name != baseline_re_name:
    re_model = AutoModelForSequenceClassification.from_pretrained(re_model_name)
else:
    # Baseline head sized through a config override.
    # NOTE(review): label2id/id2label are the NER BIO tags, not relation labels — confirm.
    cfg = AutoConfig.from_pretrained(
        re_model_name, num_labels=len(label2id), label2id=label2id, id2label=id2label
    )
    re_model = AutoModelForSequenceClassification.from_pretrained(re_model_name, config=cfg)


def tokenize_fn(batch):
    """Pad/truncate every marked candidate sentence to exactly 256 tokens."""
    return re_tokenizer(
        batch['text'],
        padding='max_length',
        truncation=True,
        max_length=256,
    )


tokenized_val = re_val_ds.map(tokenize_fn, batched=True)

# Prediction only: neither training nor an evaluation loop is requested.
eval_args = TrainingArguments(
    output_dir=f'{base_path}/data/processed/re_predict_output',
    per_device_eval_batch_size=32,
    do_train=False,
    do_eval=False,
    logging_dir=f'{base_path}/logs',
    report_to='wandb',
)
trainer = Trainer(model=re_model, args=eval_args, tokenizer=re_tokenizer)

preds_output = trainer.predict(tokenized_val)
logits = preds_output.predictions
pred_ids = logits.argmax(axis=-1)

# Prefer the model's own id -> label mapping, with keys normalised to int.
config_id2label = getattr(re_model.config, 'id2label', None)
if config_id2label:
    pred_id2label = {int(k): v for k, v in config_id2label.items()}
else:
    pred_id2label = id2label

safe = os.path.basename(re_model_name.rstrip('/'))
out_path = f'{base_path}/data/processed/re_{safe}_candidates_with_preds.json'
outputs = [
    {
        'text': ex['text'],
        'entity1_label': ex['entity1_label'],
        'entity2_label': ex['entity2_label'],
        'predicted_relation': pred_id2label.get(int(pid), 'UNKNOWN'),
    }
    for ex, pid in zip(re_val_ds, pred_ids)
]
with open(out_path, 'w') as f:
    json.dump(outputs, f, indent=2)
print(f"Wrote {len(outputs)} predictions to {out_path}")


## Pipeline option

In [None]:
# Reload the RE tokenizer/model so this option can run on its own.
re_tokenizer = AutoTokenizer.from_pretrained(re_model_name)
if re_model_name == baseline_re_name:
    # NOTE(review): label2id/id2label are the NER BIO tags, not relation labels — confirm.
    cfg = AutoConfig.from_pretrained(
        re_model_name,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label
    )
    re_model = AutoModelForSequenceClassification.from_pretrained(
        re_model_name, config=cfg
    )
    print(f"> Loaded baseline RE model '{re_model_name}' with head size {re_model.config.num_labels}")
else:
    re_model = AutoModelForSequenceClassification.from_pretrained(re_model_name)
    print(f"> Loaded trained RE model '{re_model_name}' with head size {re_model.config.num_labels}")

# Same truncation length and batch size as the Trainer option and the wandb config.
RE_MAX_LENGTH = 256
RE_BATCH_SIZE = 32

device = 0 if torch.cuda.is_available() else -1
re_pipe = pipeline(
    "text-classification",
    model=re_model,
    tokenizer=re_tokenizer,
    device=device,
    return_all_scores=False,
)

# Classify all candidates in batches. Truncation is required: without it a
# sentence longer than the model's maximum input length makes inference fail.
texts = list(re_val_ds["text"])
preds = (
    re_pipe(texts, batch_size=RE_BATCH_SIZE, truncation=True, max_length=RE_MAX_LENGTH)
    if texts else []
)

outputs = [
    {
        "text":               ex["text"],
        "entity1_label":      ex["entity1_label"],
        "entity2_label":      ex["entity2_label"],
        "predicted_relation": pred["label"],
        "score":              pred["score"],
    }
    for ex, pred in zip(re_val_ds, preds)
]

safe = os.path.basename(re_model_name.rstrip("/"))
out_path = f"{base_path}/data/processed/re_{safe}_candidates_with_preds.json"
with open(out_path, "w") as f:
    json.dump(outputs, f, indent=2)

print(f"Wrote {len(outputs)} predictions to {out_path}")


In [None]:
# Gold relations keyed by the lowercased (head, tail) strings in BOTH directions,
# expressed in DocRED label names; relations without a mapping are skipped.
# NOTE(review): keys ignore the document, so the same entity pair in different
# documents collides (last one wins) — the prediction file has no doc_title to
# disambiguate.
gold_map = {}
for meta in dataset:
    for triple in meta.get("triples", []):
        head = triple["head"].strip().lower()
        tail = triple["tail"].strip().lower()
        mapped = mapping_challenge_to_docred.get(triple["relation"])
        if not mapped:
            continue
        gold_map[(head, tail)] = mapped
        gold_map[(tail, head)] = mapped

# Predictions written by whichever RE option cell ran last.
predictions_path = out_path
with open(predictions_path, "r") as f:
    preds_list = json.load(f)

pair_keys = [
    (ex["entity1_label"].strip().lower(), ex["entity2_label"].strip().lower())
    for ex in preds_list
]
gold_labels = [gold_map.get(key, "no_relation") for key in pair_keys]
pred_labels = [ex["predicted_relation"] for ex in preds_list]

all_labels = sorted(set(gold_labels).union(pred_labels))
print(classification_report(
    gold_labels,
    pred_labels,
    labels=all_labels,
    target_names=all_labels,
    zero_division=0,
))

p, r, f1, _ = precision_recall_fscore_support(
    gold_labels, pred_labels, average="micro", zero_division=0
)
print(f"→ micro precision={p:.4f}   recall={r:.4f}   f1={f1:.4f}")

import json

detailed = [
    {
        "text":               ex["text"],
        "entity1_label":      ex["entity1_label"],
        "entity2_label":      ex["entity2_label"],
        "gold_relation":      gold,
        "predicted_relation": pred,
    }
    for ex, gold, pred in zip(preds_list, gold_labels, pred_labels)
]

out_detail_path = f"{base_path}/data/processed/detailed_re_predictions.json"
with open(out_detail_path, "w") as f:
    json.dump(detailed, f, indent=2)

print(f"Wrote {len(detailed)} detailed examples to {out_detail_path}")


In [None]:
# Log relation-extraction results: a per-label report, micro scores, and a
# one-row summary table that sits next to the NER scores.
unique_labels = sorted(set(gold_labels).union(pred_labels))

report = classification_report(
    gold_labels, pred_labels,
    labels=unique_labels, target_names=unique_labels,
    output_dict=True, zero_division=0,
)
wandb.log({"classification_report": report})

prec_re, rec_re, f1_re, _ = precision_recall_fscore_support(
    gold_labels, pred_labels,
    labels=unique_labels, average="micro", zero_division=0,
)
wandb.log({"re/precision": prec_re, "re/recall": rec_re, "re/f1": f1_re})

summary_columns = ["ner_precision", "ner_recall", "ner_f1", "re_precision", "re_recall", "re_f1"]
summary_row = [prec_ner, rec_ner, f1_ner, prec_re, rec_re, f1_re]
summary = wandb.Table(columns=summary_columns, data=[summary_row])
wandb.log({"metrics_summary": summary})

print(f"NER   → precision: {prec_ner:.4f}, recall: {rec_ner:.4f}, f1: {f1_ner:.4f}")
print(f"RE    → precision: {prec_re:.4f}, recall: {rec_re:.4f}, f1: {f1_re:.4f}")



# Exclude no relation
The RE model was trained only on DocRED relation labels and never saw a `no_relation` class, so examples whose gold label is `no_relation` are excluded here.

In [None]:
# Re-score RE on gold-positive examples only (gold_relation != "no_relation"),
# because the RE model never learned a "no_relation" class.
with open(f"{base_path}/data/processed/detailed_re_predictions.json", "r") as f:
    detailed = json.load(f)

filtered = [ex for ex in detailed if ex["gold_relation"] != "no_relation"]

gold_filt = [ex["gold_relation"]      for ex in filtered]
pred_filt = [ex["predicted_relation"] for ex in filtered]

print(f"→ {len(filtered)} positive examples (out of {len(detailed)})\n")

if filtered:
    labels = sorted(set(gold_filt) | set(pred_filt))
    print(classification_report(
        gold_filt,
        pred_filt,
        labels=labels,
        target_names=labels,
        zero_division=0
    ))

    p, r, f1, _ = precision_recall_fscore_support(
        gold_filt,
        pred_filt,
        average="micro",
        zero_division=0
    )
    print(f"→ POSITIVE-only micro precision={p:.4f}   recall={r:.4f}   f1={f1:.4f}")
else:
    print("No positive examples found; skipping positive-only metrics.")

# Save only the positive examples (previously the full `detailed` list was written here).
out_detail_positive_only_path = f"{base_path}/data/processed/detailed_re_predictions_positive_only.json"
with open(out_detail_positive_only_path, "w") as f:
    json.dump(filtered, f, indent=2)

print(f"Wrote {len(filtered)} detailed examples to {out_detail_positive_only_path}")



### Build RE validation examples

### Create HF Dataset

In [None]:
# Imported here so this cell does not depend on an earlier cell having run
# (accuracy_score was previously never imported at all).
from sklearn.metrics import accuracy_score
from transformers import Trainer, TrainingArguments

re_tokenizer  = AutoTokenizer.from_pretrained(re_model_name)
if re_model_name == baseline_re_name:
    # NOTE(review): label2id/id2label are the NER BIO tags, not relation labels — confirm.
    cfg = AutoConfig.from_pretrained(
        re_model_name,
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label
    )
    re_model = AutoModelForSequenceClassification.from_pretrained(
        re_model_name,
        config=cfg
    )
    print(f"Loaded baseline RE model '{re_model_name}' with overridden head size num_labels={re_model.config.num_labels}")
else:
    re_model = AutoModelForSequenceClassification.from_pretrained(re_model_name)
    print(f"Loaded trained RE model '{re_model_name}' with head size num_labels={re_model.config.num_labels}")


def tokenize_fn(batch):
    """Pad/truncate every marked candidate sentence to exactly 256 tokens."""
    return re_tokenizer(
        batch['text'],
        padding='max_length',
        truncation=True,
        max_length=256
    )


tokenized_val = re_val_ds.map(tokenize_fn, batched=True)


def compute_metrics(eval_pred):
    """Accuracy plus weighted precision/recall/F1 from (logits, label ids)."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    p, r, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'precision': p, 'recall': r, 'f1': f1}


eval_args = TrainingArguments(
    output_dir=f'{base_path}/data/processed/re_eval_output',
    per_device_eval_batch_size=32,
    do_train=False,
    do_eval=True,
    logging_dir=f'{base_path}/logs',
    report_to='wandb'
)
trainer = Trainer(
    model=re_model,
    args=eval_args,
    tokenizer=re_tokenizer,
    compute_metrics=compute_metrics
)

# compute_metrics needs integer gold labels; the candidate dataset only has the
# string column 'relation_label', so evaluating it would yield no metrics.
if "labels" in tokenized_val.column_names:
    eval_result = trainer.evaluate(eval_dataset=tokenized_val)
    print("🔍 RE Validation Results:", eval_result)
else:
    print("Skipping trainer.evaluate(): dataset has no 'labels' column, so no metrics can be computed.")

preds_output = trainer.predict(tokenized_val)
pred_ids = np.argmax(preds_output.predictions, axis=-1)

# Prefer the model's own id -> label mapping, with keys normalised to int.
if hasattr(re_model.config, 'id2label') and re_model.config.id2label:
    pred_id2label = {int(k): v for k, v in re_model.config.id2label.items()}
else:
    pred_id2label = id2label

import os
safe_model_name = os.path.basename(re_model_name.rstrip('/'))
output_path = f'{base_path}/data/processed/re_{safe_model_name}_predictions.json'

outputs = []
for ex, pred in zip(re_val_ds, pred_ids):
    pred_label = pred_id2label.get(int(pred), 'UNKNOWN')
    outputs.append({
        'text': ex['text'],
        'entity1_label': ex['entity1_label'],
        'entity2_label': ex['entity2_label'],
        'gold_relation': ex['relation_label'],
        'predicted_relation': pred_label
    })
with open(output_path, 'w') as f:
    json.dump(outputs, f, indent=2)
print(f"Wrote predictions to {output_path}")

### Tokenize Validation Examples

### Classification report

In [None]:
# Per-label report and micro scores computed in integer label-id space.
if hasattr(re_model.config, 'id2label') and re_model.config.id2label:
    model_id2label = {int(k): v for k, v in re_model.config.id2label.items()}
else:
    model_id2label = id2label

pred_id_list = [int(p) for p in pred_ids]
id2name = dict(model_id2label)

if "labels" in re_val_ds.column_names:
    true_ids = [int(t) for t in re_val_ds["labels"]]
else:
    # The candidate dataset only has gold relation *names* ('relation_label');
    # reading a 'labels' column raised KeyError. Encode the names with the model's
    # ids; names the model does not know (e.g. "no_relation") get fresh ids so
    # they count as errors instead of crashing the lookup.
    name2id = {name: i for i, name in model_id2label.items()}
    next_id = max(id2name, default=-1) + 1
    true_ids = []
    for name in re_val_ds["relation_label"]:
        if name not in name2id:
            name2id[name] = next_id
            id2name[next_id] = name
            next_id += 1
        true_ids.append(name2id[name])

unique_labels = sorted(set(true_ids) | set(pred_id_list))
target_names = [id2name.get(l, str(l)) for l in unique_labels]

report = classification_report(
    true_ids,
    pred_id_list,
    labels=unique_labels,
    target_names=target_names,
    output_dict=True,
    zero_division=0
)
wandb.log({"classification_report": report})

prec_re, rec_re, f1_re, _ = precision_recall_fscore_support(
    true_ids,
    pred_id_list,
    labels=unique_labels,
    average='micro',
    zero_division=0
)
wandb.log({
    "re/precision": prec_re,
    "re/recall":    rec_re,
    "re/f1":        f1_re,
})

summary_table = wandb.Table(
    columns=[
      "ner_precision","ner_recall","ner_f1",
      "re_precision", "re_recall", "re_f1"
    ],
    data=[[prec_ner, rec_ner, f1_ner, prec_re, rec_re, f1_re]]
)
wandb.log({"metrics_summary": summary_table})


In [None]:
# Collapse near-equivalent relation names onto one canonical name before scoring.
# NOTE(review): "part of" -> "HasPart" merges two inverse relations — confirm intended.
equiv = {
    "HasPart": "HasPart",
    "part of": "HasPart",
    "BasedOn":      "BasedOn",
    "HasEffect":    "HasEffect",
    "Causes":       "Causes",
    "influenced by":"InfluencedBy",
    "InfluencedBy": "InfluencedBy",
}

gold_str = re_val_ds["relation_label"]
# int() so numpy integer ids look up the int-keyed mapping reliably.
pred_str = [pred_id2label.get(int(p), "UNKNOWN") for p in pred_ids]

gold_norm = [equiv.get(g, g) for g in gold_str]
pred_norm = [equiv.get(p, p) for p in pred_str]

unique_labels_str = sorted(set(gold_norm) | set(pred_norm))

report = classification_report(
    gold_norm,
    pred_norm,
    labels=unique_labels_str,
    target_names=unique_labels_str,
    output_dict=True,
    zero_division=0
)
wandb.log({"classification_report": report})

prec_re, rec_re, f1_re, _ = precision_recall_fscore_support(
    gold_norm,
    pred_norm,
    labels=unique_labels_str,
    average='micro',
    zero_division=0
)
wandb.log({
    "re/precision": prec_re,
    "re/recall":    rec_re,
    "re/f1":        f1_re,
})

print(prec_re)
print(rec_re)
print(f1_re)


## Wrap up


In [None]:
# End the wandb run opened with wandb.init(...).
wandb.finish()