In [1]:
%load_ext autoreload
%autoreload 2

# Weak Supervision VAE

Implementing some version of [Weak Supervision Variational Auto-Encoder](https://openreview.net/forum?id=0oDzoRjrbj) by Tonolini et al. 

There are some differences with the paper especially around the decoders and its still a work-in-progress as its mainly tested around the census data at the moment.

In [2]:
import os
import sys

os.chdir("../..")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import torch

import flippers
from examples.utils import load_wrench_dataset, MetricsUtil
from examples.Experiments._vae import WeakLabelVAE

M = {}

In [3]:
datasets = ["cdr", "yelp", "youtube", "census", "spouse", "basketball", "sms", "tennis"]
for dataset in datasets:
    data, weak_labels, polarities = load_wrench_dataset(dataset)
    train, dev, test = data
    L_train, L_dev, L_test = weak_labels
    class_balances = list(
        dev["label"].astype(int).value_counts(normalize=True).sort_index()
    )
    y_test = test["label"].astype(int).values
    Metrics = MetricsUtil(y_test, L_test)

    m = WeakLabelVAE(polarities=polarities, class_balances=class_balances)
    m.fit(L_train)
    M[dataset] = Metrics.score(m, name="WeakLabelVAE", plots=False)
    print(dataset, M[dataset])

Make sure L captures all possible values of each weak labelers.
Epoch [10/10]: 100%|██████████| 10/10 [00:08<00:00,  1.12it/s, Loss=5.4]


cdr {'F1': 0.665, 'Average_Precision': 0.631, 'AUC': 0.802, 'Accuracy': 0.748, 'Balanced_Accuracy': 0.755}


Make sure L captures all possible values of each weak labelers.
Epoch [3/3]: 100%|██████████| 3/3 [00:07<00:00,  2.51s/it, Loss=6.5]


yelp {'F1': 0.704, 'Average_Precision': 0.786, 'AUC': 0.763, 'Accuracy': 0.683, 'Balanced_Accuracy': 0.681}


Make sure L captures all possible values of each weak labelers.
Epoch [51/51]: 100%|██████████| 51/51 [00:06<00:00,  7.56it/s, Loss=13.3]


youtube {'F1': 0.0, 'Average_Precision': 0.349, 'AUC': 0.168, 'Accuracy': 0.528, 'Balanced_Accuracy': 0.5}


Make sure L captures all possible values of each weak labelers.
Epoch [8/8]: 100%|██████████| 8/8 [00:07<00:00,  1.06it/s, Loss=5.8]


census {'F1': 0.575, 'Average_Precision': 0.566, 'AUC': 0.81, 'Accuracy': 0.717, 'Balanced_Accuracy': 0.75}


Make sure L captures all possible values of each weak labelers.
Epoch [4/4]: 100%|██████████| 4/4 [00:07<00:00,  1.82s/it, Loss=5.6]


spouse {'F1': 0.456, 'Average_Precision': 0.275, 'AUC': 0.783, 'Accuracy': 0.89, 'Balanced_Accuracy': 0.744}


Make sure L captures all possible values of each weak labelers.
Epoch [5/5]: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it, Loss=4.2] 


basketball {'F1': 0.174, 'Average_Precision': 0.104, 'AUC': 0.516, 'Accuracy': 0.549, 'Balanced_Accuracy': 0.515}


Make sure L captures all possible values of each weak labelers.
Epoch [18/18]: 100%|██████████| 18/18 [00:07<00:00,  2.31it/s, Loss=1.6]


sms {'F1': 0.136, 'Average_Precision': 0.136, 'AUC': 0.5, 'Accuracy': 0.796, 'Balanced_Accuracy': 0.51}


Make sure L captures all possible values of each weak labelers.
Epoch [12/12]: 100%|██████████| 12/12 [00:06<00:00,  1.74it/s, Loss=7.0]

tennis {'F1': 0.808, 'Average_Precision': 0.786, 'AUC': 0.885, 'Accuracy': 0.86, 'Balanced_Accuracy': 0.864}





In [4]:
from examples.utils import dataset_to_metric

df = pd.DataFrame(M)


# Formatting function to underline entries
def underline_entries(x):
    is_metric = x.index == dataset_to_metric[x.name]
    return ["text-decoration: underline" if v else "" for v in is_metric]


# Underlines the metrics used in the wrench benchmark
df.style.apply(underline_entries).format("{:.3f}")

Unnamed: 0,cdr,yelp,youtube,census,spouse,basketball,sms,tennis
F1,0.665,0.704,0.0,0.575,0.456,0.174,0.136,0.808
Average_Precision,0.631,0.786,0.349,0.566,0.275,0.104,0.136,0.786
AUC,0.802,0.763,0.168,0.81,0.783,0.516,0.5,0.885
Accuracy,0.748,0.683,0.528,0.717,0.89,0.549,0.796,0.86
Balanced_Accuracy,0.755,0.681,0.5,0.75,0.744,0.515,0.51,0.864
