In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Weak Supervision VAE

This notebook implements a variant of the [Weak Supervision Variational Auto-Encoder](https://openreview.net/forum?id=0oDzoRjrbj) by Tonolini et al.

There are some differences from the paper, especially around the decoders, and it is still a work in progress: at the moment it has mainly been tested on the census data.

In [2]:
import os
import sys

# Switch the working directory two levels up before the project imports below.
# NOTE(review): presumably this is what makes `flippers` and `examples`
# importable (repository root on the import path) -- confirm.
# The path is relative, so re-running this cell moves up two MORE levels;
# restart the kernel before executing it again.
os.chdir("../..")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import torch

import flippers
from examples.utils import load_wrench_dataset, MetricsUtil
from examples.Experiments._vae import WeakLabelVAE

# Results accumulator: dataset name -> value returned by Metrics.score for that
# dataset. Filled in by the benchmark loop and read when building the summary table.
M = {}

In [5]:
# Fixed seed so that a fresh "Restart & Run All" reproduces the same scores.
SEED = 0

# Datasets to benchmark; each name is passed to load_wrench_dataset.
datasets = ["cdr", "yelp", "youtube", "census", "spouse", "basketball", "sms", "tennis"]
for dataset in datasets:
    # Re-seed for every dataset so that a score does not depend on which
    # datasets were trained before it (or on the order of `datasets`).
    # NOTE(review): assumes WeakLabelVAE draws its randomness from the global
    # numpy / torch generators -- confirm against examples/Experiments/_vae.py.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # load_wrench_dataset returns (splits, weak-label matrices, polarities);
    # the first two unpack into train / dev / test parts.
    data, weak_labels, polarities = load_wrench_dataset(dataset)
    train, dev, test = data
    L_train, L_dev, L_test = weak_labels

    # Class proportions measured on the dev labels, ordered by class id.
    class_balances = list(
        dev["label"].astype(int).value_counts(normalize=True).sort_index()
    )
    y_test = test["label"].astype(int).values
    Metrics = MetricsUtil(y_test, L_test)

    # The model is fitted on the training weak labels only; the metrics helper
    # built from the test labels / test weak labels produces the scores.
    m = WeakLabelVAE(polarities=polarities, class_balances=class_balances)
    m.fit(L_train)
    M[dataset] = Metrics.score(m, name="WeakLabelVAE", plots=False)
    print(dataset, M[dataset])

Make sure L captures all possible values of each weak labelers.
Epoch [10/10]: 100%|██████████| 10/10 [00:06<00:00,  1.45it/s, Loss=1.7]


cdr {'F1': 0.645, 'Average_Precision': 0.637, 'AUC': 0.825, 'Accuracy': 0.759, 'Balanced_Accuracy': 0.737}


Make sure L captures all possible values of each weak labelers.
Epoch [3/3]: 100%|██████████| 3/3 [00:07<00:00,  2.39s/it, Loss=5.1] 


yelp {'F1': 0.756, 'Average_Precision': 0.796, 'AUC': 0.777, 'Accuracy': 0.735, 'Balanced_Accuracy': 0.732}


Make sure L captures all possible values of each weak labelers.
Epoch [51/51]: 100%|██████████| 51/51 [00:06<00:00,  7.57it/s, Loss=7.6] 


youtube {'F1': 0.802, 'Average_Precision': 0.911, 'AUC': 0.882, 'Accuracy': 0.836, 'Balanced_Accuracy': 0.829}


Make sure L captures all possible values of each weak labelers.
Epoch [8/8]: 100%|██████████| 8/8 [00:07<00:00,  1.00it/s, Loss=2.2]


census {'F1': 0.616, 'Average_Precision': 0.511, 'AUC': 0.82, 'Accuracy': 0.765, 'Balanced_Accuracy': 0.776}


Make sure L captures all possible values of each weak labelers.
Epoch [4/4]: 100%|██████████| 4/4 [00:06<00:00,  1.74s/it, Loss=2.8]


spouse {'F1': 0.487, 'Average_Precision': 0.388, 'AUC': 0.803, 'Accuracy': 0.886, 'Balanced_Accuracy': 0.787}


Make sure L captures all possible values of each weak labelers.
Epoch [5/5]: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it, Loss=5.4] 


basketball {'F1': 0.174, 'Average_Precision': 0.108, 'AUC': 0.525, 'Accuracy': 0.549, 'Balanced_Accuracy': 0.515}


Make sure L captures all possible values of each weak labelers.
Epoch [18/18]: 100%|██████████| 18/18 [00:07<00:00,  2.26it/s, Loss=0.7]


sms {'F1': 0.051, 'Average_Precision': 0.15, 'AUC': 0.508, 'Accuracy': 0.852, 'Balanced_Accuracy': 0.505}


Make sure L captures all possible values of each weak labelers.
Epoch [12/12]: 100%|██████████| 12/12 [00:06<00:00,  1.78it/s, Loss=3.5]

tennis {'F1': 0.808, 'Average_Precision': 0.798, 'AUC': 0.886, 'Accuracy': 0.86, 'Balanced_Accuracy': 0.864}





In [6]:
from examples.utils import dataset_to_metric

# Build the summary table from the nested results: the outer keys of M (dataset
# names) become the columns, the metric names of each entry become the row index.
df = pd.DataFrame(M)


# Formatting function to underline entries
def underline_entries(x):
    """Return one CSS string per entry of ``x``, underlining the benchmark metric.

    ``x`` is a labelled 1-D series whose ``name`` is a key of
    ``dataset_to_metric`` and whose index holds metric names. The entry whose
    index label equals ``dataset_to_metric[x.name]`` gets
    ``"text-decoration: underline"``; every other entry gets an empty string.

    Raises ``KeyError`` if ``x.name`` is not present in ``dataset_to_metric``.
    """
    css_by_match = {True: "text-decoration: underline", False: ""}
    # Element-wise comparison of the index labels against the metric to highlight.
    matches = x.index == dataset_to_metric[x.name]
    return [css_by_match[bool(flag)] for flag in matches]


# Style the table column by column (one column per dataset): the metric used for
# that dataset in the wrench benchmark is underlined, and every value is shown
# with three decimals.
df.style.apply(underline_entries, axis=0).format(formatter="{:.3f}")

Unnamed: 0,cdr,yelp,youtube,census,spouse,basketball,sms,tennis
F1,0.645,0.756,0.802,0.616,0.487,0.174,0.051,0.808
Average_Precision,0.637,0.796,0.911,0.511,0.388,0.108,0.15,0.798
AUC,0.825,0.777,0.882,0.82,0.803,0.525,0.508,0.886
Accuracy,0.759,0.735,0.836,0.765,0.886,0.549,0.852,0.86
Balanced_Accuracy,0.737,0.732,0.829,0.776,0.787,0.515,0.505,0.864
