In [2]:
%load_ext autoreload
%autoreload 2

# Tutorial: How to train a classifier using Weak Supervision?

In this tutorial, we are going to train a spam detection classifier using weakly supervised data. 

The steps:
- Collect training data
- Annotate this data in a weakly supervised setting
    - Create labeling functions
    - *Match* the labeling functions to the data samples
    - Aggregate the labels with different label aggregation techniques
        - Majority Vote
        - FABLE 
- Train a logistic regression classifier using weak labels
- Train a logistic regression classifier with SepLL

In [72]:
# necessary imports
import sys

sys.path.append("..")

import logging
import pandas as pd
import numpy as np

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)


from wrench.utils import set_seed
from wrench.endmodel import EndClassifierModel
from wrench._logging import LoggingHandler


from snorkel.utils import probs_to_preds
from utils import load_raw_spam_dataset


#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

logger = logging.getLogger(__name__)

In [4]:
# the path to the folder where our data is stored

path_to_data = "data"

## Data

The dataset we will use for training is the Spam Detection YouTube comments dataset [3].

- The dataset consists of comments that YouTube users left under different videos.
- Each sample is a comment (i.e., a word, a sentence, or a couple of sentences).
- 1,586 train samples, 120 dev samples, 250 test samples
- There are 2 types of samples:
    - HAM: comments relevant to the video (even very simple ones), or
    - SPAM: irrelevant (often trying to advertise something) or inappropriate messages
    
<img src="../img/spam_detection.png" width="800"/>

**NB! The original dataset is manually labeled, but we won't use these gold labels for model training! We will treat the dataset as unlabeled (and label it in a weakly supervised fashion).** 

Let's first have a look at the dataset.

In [5]:
# load the YouTube dataset

df_train, df_dev, df_test = load_raw_spam_dataset(load_train_labels=True)
# Y_train = df_train["label"].values
# Y_test = df_test["label"].values

In [6]:
df_train[:10]

Unnamed: 0,author,date,text,label,video
0,Alessandro leite,2014-11-05T22:21:36,pls http://www10.vakinha.com.br/VaquinhaE.aspx?e=313327 help me get vip gun cross fire al﻿,1,1
1,Salim Tayara,2014-11-02T14:33:30,"if your like drones, plz subscribe to Kamal Tayara. He takes videos with his drone that are absolutely beautiful.﻿",1,1
2,Phuc Ly,2014-01-20T15:27:47,go here to check the views :3﻿,0,1
3,DropShotSk8r,2014-01-19T04:27:18,"Came here to check the views, goodbye.﻿",0,1
4,css403,2014-11-07T14:25:48,"i am 2,126,492,636 viewer :D﻿",0,1
5,Giang Nguyen,2014-11-06T04:55:41,https://www.facebook.com/teeLaLaLa﻿,1,1
6,Caius Ballad,2014-11-13T00:58:20,imagine if this guy put adsense on with all these views... u could pay ur morgage﻿,0,1
7,Holly,2014-11-06T13:41:30,Follow me on Twitter @mscalifornia95﻿,1,1
8,King uzzy,2014-11-07T23:19:08,Can we reach 3 billion views by December 2014? ﻿,0,1
9,iKap Taz,2014-11-08T13:34:27,Follow 4 Follow @ VaahidMustafic Like 4 Like ﻿,1,1


For each data sample in the original dataset (i.e., a YouTube comment), we know:
- comment's author,
- date when the corresponding comment was left,
- text of the sample,
- gold manual label,
- id of the YouTube video.

In [7]:
# some examples of non-spam (HAM) samples, label id 0

df_train.loc[df_train["label"]==0][:10]

Unnamed: 0,author,date,text,label,video
2,Phuc Ly,2014-01-20T15:27:47,go here to check the views :3﻿,0,1
3,DropShotSk8r,2014-01-19T04:27:18,"Came here to check the views, goodbye.﻿",0,1
4,css403,2014-11-07T14:25:48,"i am 2,126,492,636 viewer :D﻿",0,1
6,Caius Ballad,2014-11-13T00:58:20,imagine if this guy put adsense on with all these views... u could pay ur morgage﻿,0,1
8,King uzzy,2014-11-07T23:19:08,Can we reach 3 billion views by December 2014? ﻿,0,1
10,John Plaatt,2014-11-07T22:22:29,On 0:02 u can see the camera man on his glasses....﻿,0,1
11,Praise Samuel,2014-11-08T11:10:30,2 billion views wow not even baby by justin beibs has that much he doesn't deserve a capitalized name﻿,0,1
16,zhichao wang,2013-11-29T02:13:56,i think about 100 millions of the views come from people who only wanted to check the views﻿,0,1
19,Tedi Foto,2014-11-08T09:33:30,What my gangnam style﻿,0,1
20,Tee Tee,2014-11-07T20:16:51,Loool nice song funny how no one understands (me) and we love it﻿,0,1


In [8]:
# some examples of spam (SPAM) samples, label id 1

df_train.loc[df_train["label"]==1][:10]

Unnamed: 0,author,date,text,label,video
0,Alessandro leite,2014-11-05T22:21:36,pls http://www10.vakinha.com.br/VaquinhaE.aspx?e=313327 help me get vip gun cross fire al﻿,1,1
1,Salim Tayara,2014-11-02T14:33:30,"if your like drones, plz subscribe to Kamal Tayara. He takes videos with his drone that are absolutely beautiful.﻿",1,1
5,Giang Nguyen,2014-11-06T04:55:41,https://www.facebook.com/teeLaLaLa﻿,1,1
7,Holly,2014-11-06T13:41:30,Follow me on Twitter @mscalifornia95﻿,1,1
9,iKap Taz,2014-11-08T13:34:27,Follow 4 Follow @ VaahidMustafic Like 4 Like ﻿,1,1
12,Malin Linford,2014-11-05T01:13:43,"Hey guys please check out my new Google+ page it has many funny pictures, FunnyTortsPics https://plus.google.com/112720997191206369631/post﻿",1,1
13,Lone Twistt,2013-11-28T17:34:55,Once you have started reading do not stop. If you do not subscribe to me within one day you and you're entire family will die so if you want to stay alive subscribe right now.﻿,1,1
14,Олег Пась,2014-11-03T23:29:00,Plizz withing my channel ﻿,1,1
15,JD COKE,2014-11-08T02:24:02,"It's so hard, sad :( iThat little child Actor HWANG MINOO dancing very active child is suffering from brain tumor, only 6 month left for him .Hard to believe .. Keep praying everyone for our future superstar. #StrongLittlePsY #Fighting SHARE EVERYONE PRAYING FOR HIM http://ygunited.com/2014/11/08/little-psy-from-the-has-brain-tumor-6-months-left-to-live/ ﻿",1,1
17,Rancy Gaming,2014-11-06T09:41:07,What free gift cards? Go here http://www.swagbucks.com/p/register?rb=13017194﻿,1,1


In [9]:
df_train[["text", "label"]][:20]

Unnamed: 0,text,label
0,pls http://www10.vakinha.com.br/VaquinhaE.aspx?e=313327 help me get vip gun cross fire al﻿,1
1,"if your like drones, plz subscribe to Kamal Tayara. He takes videos with his drone that are absolutely beautiful.﻿",1
2,go here to check the views :3﻿,0
3,"Came here to check the views, goodbye.﻿",0
4,"i am 2,126,492,636 viewer :D﻿",0
5,https://www.facebook.com/teeLaLaLa﻿,1
6,imagine if this guy put adsense on with all these views... u could pay ur morgage﻿,0
7,Follow me on Twitter @mscalifornia95﻿,1
8,Can we reach 3 billion views by December 2014? ﻿,0
9,Follow 4 Follow @ VaahidMustafic Like 4 Like ﻿,1


Now it is time to start weak supervision! So, let's imagine the gold labels disappeared... 

<img src="../img/poof.jpg" width="300"/>

... and here we are: there is some data we want to use for classifier training, but we don't have any labels and capacity/time/money/... for hiring annotators.

But we can label this data with **weak supervision** :)

<img src="../img/rainbow.png" width="500"/>

# Weak Supervision

A brief reminder of how weak supervision works:
1. We come up with some heuristic rules and transform these rules into labeling functions.
2. We apply these labeling functions to the data and obtain weak labels.
3. We use these weak labels to train a classifier. 

Let's have a closer look at the training samples we have:

In [10]:
list(df_train.text[100:120])

['how is this shit still relevant \ufeff',
 ' Hey everyone!! I have just started my first YT channel i would be grateful  if some of you peoples could check out my first clip in BF4! and give me  some advice on how my video was and how i could improve it. ALSO be sure to  go check out the about to see what Im all about. Thanks for your time :) .  and to haters... You Hate, I WIN\ufeff',
 'The Funny Thing Is That this song was made in 2009 but it took 2 years to  get to america.\ufeff',
 'Why dafuq is a Korean song so big in the USA. Does that mean we support  Koreans? Last time I checked they wanted to bomb us. \ufeff',
 'People Who Say That "This Song Is Too Old Now, There\'s No Point Of  Listening To It" Suck. Just Stfu And Enjoy The Music. So, Your Mom Is Old  Too But You Still Listen To Her Right?....\ufeff',
 'Follow me on twitter &amp; IG : __killuminati94\ufeff',
 'how does this video have 2,127,322,484 views if there are only 7 million  people on earth?\ufeff',
 'Just coming to

## Task: formulate the rules that could annotate the training samples

The questions that might help you: 

*What patterns are typical for spam YouTube comments? for non-spam comments?*

*What rules might help to distinguish between spam and not-spam YouTube comments?*

*What labeling functions do you think are productive and useful to annotate the YouTube comments?*

Rules: 

1. ...
2. ...
3. ...
4. ...
5. ...
6. ...
7. ...
8. ...
9. ...
10. ...



My examples of rules: 
- "check"/"check out": if there is a collocation "check out" in the comment, most probably this comment is spam (and the comment author is promoting his/her channel)
- "subscribe": same
- "my": same
- ...
 

### What can be a rule?

- Keyword searches: looking for specific words in a sentence
- Pattern matching: looking for specific syntactical patterns
- Third-party models: using a pre-trained model (usually a model for a different task than the one at hand)
- ...
- Crowdworker labels: treating each crowdworker as a black-box function that assigns labels to subsets of the data

### Rules into labeling functions

After we collected some rules, we transform them into labeling functions that could *label* the data sample - that is, assign it to one or another class. 

In [11]:
# an example of LF based on a keyword "check out"

def check_out(x):
    return 1 if "check out" in x.text.lower() else -1

# meaning the sample will be assigned to class 1 (=SPAM) if there is a "check out" expression in the comment; 
# otherwise the labeling function abstains (returns -1) and assigns no class

In [12]:
# an example of LF based on a keyword "please"

def please(x):
    return 1 if "please" in x.text.lower() else -1

# meaning the sample will be assigned to class 1 (=SPAM) if the word "please" occurs in the comment; 
# otherwise the labeling function abstains (returns -1) and assigns no class
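
To illustrate the *matching* step, here is a minimal sketch of applying several labeling functions to a handful of comments and collecting the votes into a weak label matrix. The `lf_` names, the regex, and the five-word threshold are assumptions made for illustration; the first two comments are invented, the last two come from the dataset shown above. `SimpleNamespace` only stands in for a DataFrame row with a `text` attribute.

```python
import re
from types import SimpleNamespace

ABSTAIN, HAM, SPAM = -1, 0, 1

def lf_check_out(x):
    # keyword LF: vote SPAM on the exact phrase "check out", otherwise abstain
    return SPAM if "check out" in x.text.lower() else ABSTAIN

def lf_please(x):
    # keyword LF: vote SPAM if the comment contains "please"
    return SPAM if "please" in x.text.lower() else ABSTAIN

def lf_regex_check_out(x):
    # pattern LF: also catches variants such as "check it out"
    return SPAM if re.search(r"check.*out", x.text, flags=re.I) else ABSTAIN

def lf_short_comment(x):
    # heuristic LF: very short comments vote HAM (threshold is an assumption)
    return HAM if len(x.text.split()) < 5 else ABSTAIN

lfs = [lf_check_out, lf_please, lf_regex_check_out, lf_short_comment]

comments = [
    "Please check out my new channel",                 # invented example
    "Check it out, free gift cards here",              # invented example
    "What my gangnam style",                           # from the dataset
    "Can we reach 3 billion views by December 2014?",  # from the dataset
]
samples = [SimpleNamespace(text=t) for t in comments]

# one row per sample, one column per labeling function
weak_labels = [[lf(x) for lf in lfs] for x in samples]

for text, row in zip(comments, weak_labels):
    print(row, text)
```

Each row holds one vote per labeling function, which is the same layout as the `weak_labels` field used in the rest of this tutorial.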

### Labeling functions we are going to use

In this tutorial, we are going to use the labeling functions created by the [Snorkel team](https://github.com/snorkel-team/snorkel-tutorials/blob/master/spam/01_spam_tutorial.ipynb), which are: 


1. keyword **"my"** (to detect spam comments like "my channel", "my video", etc)
2. keyword **"subscribe"** (to detect spam comments that ask users to subscribe to some channel)
3. keyword **"http"** (to detect spam comments that link to other channels)
4. keyword **"please"/"plz"** (to detect spam comments that make requests rather than commenting)
5. keyword **"song"** (to detect non-spam comments that actually talk about the video's content)
6. regex **"check_out"** (to detect spam comments like "check out this channel", etc)
7. **short comment** (non-spam comments are often short, such as 'cool video!')
8. **short comments that mention specific people** (using the spaCy library; non-spam comments usually mention some people)
9. **polarity** (using TextBlob library; if polarity > 0.9, it is most probably a non-spam message)
10. **subjectivity** (using TextBlob library; if subjectivity >= 0.5, it is most probably a non-spam message)

(We will not go into the details of the labeling process here; you will hear more about it from my colleagues later.) 

### Processed data

The resulting annotations can be saved in the following format: 

In [13]:
import json
with open("data/youtube/train.json") as train_file:
    train_data = json.load(train_file)
train_data["1"]

{'data': {'text': 'if your like drones, plz subscribe to Kamal Tayara. He takes videos with  his drone that are absolutely beautiful.\ufeff'},
 'label': 1,
 'weak_labels': [-1, 1, -1, 1, -1, -1, -1, -1, -1, 0]}

The structure of the processed data is the following: 
- data.text: the text of the sample
- label: gold label obtained by manual annotation
- weak_labels: the results of annotation by labeling functions. 
    - -1: the corresponding labeling function did not match
    - 0: the labeling function matched and assigned this sample to class 0 (non-spam class in our case)
    - 1: the labeling function matched and assigned this sample to class 1 (spam class in our case)

So, for the sample #1:
(*if your like drones, plz subscribe to Kamal Tayara. He takes videos with  his drone that are absolutely beautiful.\ufeff*)

- labeling functions 1, 3, 5, 6, 7, 8, 9 did not match
- labeling functions 2 (a key word *subscribe*) & 4 (a key word *plz*) matched and assigned this sample to the class 1
- labeling function 10 (subjectivity score >= 0.5) matched and assigned this sample to the class 0
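
The reading above can be reproduced with a few lines of code. This is a small sketch, assuming the entries of `weak_labels` follow the order of the labeling function list given earlier; the short names in `lf_names` and the helper `describe_votes` are made up for illustration.

```python
ABSTAIN = -1
id2label = {0: "HAM", 1: "SPAM"}

# short, hypothetical names for the ten labeling functions, in list order
lf_names = [
    "my", "subscribe", "http", "please/plz", "song",
    "check_out", "short comment", "person mention", "polarity", "subjectivity",
]

# weak labels of sample #1, copied from train.json
weak_labels = [-1, 1, -1, 1, -1, -1, -1, -1, -1, 0]

def describe_votes(weak_labels, lf_names):
    """Map each labeling function that matched to the class it voted for."""
    return {
        name: id2label[vote]
        for name, vote in zip(lf_names, weak_labels)
        if vote != ABSTAIN
    }

votes = describe_votes(weak_labels, lf_names)
print(votes)
# {'subscribe': 'SPAM', 'please/plz': 'SPAM', 'subjectivity': 'HAM'}
```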

**Next step: how to turn these annotations into weak labels to train a classifier with them?**

## Weak labels

There are different *label models* that calculate the weak labels from the labeling functions' annotations. In this tutorial, we are going to try two of them: 

- **Majority Vote** (intuitive and straightforward)
- **FABLE** [1] (most recent and well-performing)

For label calculation and model training we will use a weak supervision framework called [Wrench](https://github.com/JieyuZ2/wrench) [2].

### Wrench dataset

First, we transform our data into a Wrench-specific dataset.

We can encode the data with TF-IDF features... 

In [14]:
# TF-IDF features

from wrench.dataset import load_dataset

train_data_tfidf, valid_data_tfidf, test_data_tfidf = load_dataset(
    path_to_data,     # path to the folder where the dataset is stored
    "youtube",         # name of the dataset
    extract_feature=True,      # we want to encode our data ...
    extract_fn='tfidf'        # ... with TF-IDF features (other predefined options are 'sentence_transformer', 'bert')
)

2023-09-29 07:07:18 - loading data from data/youtube/train.json


  0%|          | 0/1586 [00:00<?, ?it/s]

2023-09-29 07:07:18 - loading data from data/youtube/valid.json


  0%|          | 0/120 [00:00<?, ?it/s]

2023-09-29 07:07:19 - loading data from data/youtube/test.json


  0%|          | 0/250 [00:00<?, ?it/s]

... or with BERT features.

In [15]:
# Bert features

train_data, valid_data, test_data = load_dataset(
    path_to_data,       # path to the folder where the dataset is stored
    "youtube",    # name of the dataset
    extract_feature=True,      # we want to encode our data ...
    extract_fn='bert',        # ... with bert embeddings
    model_name='bert-base-cased',      # the name of the bert model
    cache_name='bert'     # load it from cache if there are cached files 
)

2023-09-29 07:07:20 - loading data from data/youtube/train.json


  0%|          | 0/1586 [00:00<?, ?it/s]

2023-09-29 07:07:20 - loading data from data/youtube/valid.json


  0%|          | 0/120 [00:00<?, ?it/s]

2023-09-29 07:07:20 - loading data from data/youtube/test.json


  0%|          | 0/250 [00:00<?, ?it/s]

2023-09-29 07:07:20 - loading features from data/youtube/train_bert.pkl
2023-09-29 07:07:20 - loading features from data/youtube/valid_bert.pkl
2023-09-29 07:07:20 - loading features from data/youtube/test_bert.pkl


Let's have a look at what's inside. 

In [16]:
# the format of train_data, valid_data, and test_data is now: wrench.dataset.dataset.TextDataset

train_data

<wrench.dataset.dataset.TextDataset at 0x7f85107da670>

In [17]:
# how many classes are there in the dataset?

train_data.n_class

2

In [18]:
# how many labeling functions are there in the dataset?

train_data.n_lf

10

In [19]:
# what is the class_id to class correspondence?

train_data.id2label

{0: 'HAM', 1: 'SPAM'}

In [20]:
# what do the samples look like?

train_data.examples[:10]

[{'text': 'pls http://www10.vakinha.com.br/VaquinhaE.aspx?e=313327 help me get vip gun  cross fire al\ufeff'},
 {'text': 'if your like drones, plz subscribe to Kamal Tayara. He takes videos with  his drone that are absolutely beautiful.\ufeff'},
 {'text': 'go here to check the views :3\ufeff'},
 {'text': 'Came here to check the views, goodbye.\ufeff'},
 {'text': 'i am 2,126,492,636 viewer :D\ufeff'},
 {'text': 'https://www.facebook.com/teeLaLaLa\ufeff'},
 {'text': 'imagine if this guy put adsense on with all these views... u could pay ur  morgage\ufeff'},
 {'text': 'Follow me on Twitter @mscalifornia95\ufeff'},
 {'text': 'Can we reach 3 billion views by December 2014? \ufeff'},
 {'text': 'Follow 4 Follow                           @ VaahidMustafic Like 4 Like \ufeff'}]

In [21]:
# what do the encoded samples look like?

print(type(train_data.features))
train_data.features[:10]

<class 'numpy.ndarray'>


array([[-0.7608848 ,  0.4292751 ,  0.99990165, ...,  0.99997175,
        -0.7515034 ,  0.9910971 ],
       [-0.7922811 ,  0.47435865,  0.9999021 , ...,  0.9999663 ,
        -0.7555875 ,  0.9860616 ],
       [-0.71461093,  0.4088342 ,  0.9997614 , ...,  0.9999327 ,
        -0.5964781 ,  0.98014975],
       ...,
       [-0.7311086 ,  0.40834075,  0.9998686 , ...,  0.99996996,
        -0.6895463 ,  0.9860986 ],
       [-0.606766  ,  0.48277682,  0.99963987, ...,  0.9999085 ,
        -0.89470583,  0.97945964],
       [-0.7177589 ,  0.49044982,  0.9998655 , ...,  0.999958  ,
        -0.82362586,  0.9884726 ]], dtype=float32)

In [22]:
# what are the weak annotations produced by labeling functions?

train_data.weak_labels[3]

[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
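
Sample 3 received no votes at all, which raises the question of how often that happens. The sketch below computes three simple diagnostics on the first ten rows of `train_data.weak_labels`: the fraction of samples with at least one vote (coverage), the per-function coverage, and the fraction of samples with disagreeing votes (conflict). The function names are hypothetical, and ten rows are far too few for real conclusions; the same functions could be run on the full matrix.

```python
ABSTAIN = -1

# first ten rows of train_data.weak_labels
weak_labels = [
    [-1, -1, 1, -1, -1, -1, -1, -1, -1, -1],
    [-1, 1, -1, 1, -1, -1, -1, -1, -1, 0],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
    [-1, -1, 1, -1, -1, -1, 0, -1, -1, -1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
]

def coverage(rows):
    # fraction of samples on which at least one labeling function voted
    return sum(any(v != ABSTAIN for v in row) for row in rows) / len(rows)

def lf_coverage(rows):
    # for each labeling function, the fraction of samples it voted on
    return [sum(v != ABSTAIN for v in col) / len(rows) for col in zip(*rows)]

def conflict(rows):
    # fraction of samples that received votes for more than one class
    return sum(len({v for v in row if v != ABSTAIN}) > 1 for row in rows) / len(rows)

total_coverage = coverage(weak_labels)
per_lf_coverage = lf_coverage(weak_labels)
conflict_rate = conflict(weak_labels)

print(total_coverage, conflict_rate)
print(per_lf_coverage)
```

On these ten rows only three samples receive any vote, and two of those three receive conflicting votes. This is exactly the situation a label model has to resolve.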

### Majority Vote

The simplest and most straightforward method to calculate labels from the noisy annotations is **majority voting** - a decision-making method where the option with the most votes is chosen. It's like asking a group of people to pick a movie, and the one that gets the most hands raised wins. 

In our case, each labeling function produces a *vote*; the most voted class is selected as a sample label. All ties are broken randomly.


#### Task: write your own majority vote function
- Input: the weak annotations produced by labeling functions (stored in weak_labels field of wrench dataset objects)
- Output: labels

Before you start programming, think about possible bottlenecks: 
- what if a sample obtains an equal number of votes for several classes?
- what if there are no votes for a sample?

In [76]:
train_data.weak_labels[:10]

[[-1, -1, 1, -1, -1, -1, -1, -1, -1, -1],
 [-1, 1, -1, 1, -1, -1, -1, -1, -1, 0],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, 1, -1, -1, -1, 0, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]

In [77]:
# todo
def majority_vote(weak_annotations):
    # calculate labels with majority vote
    # output should be a numpy array of shape (number of training samples) x 1
    labels = []
    # todo
    return np.array(labels)

labels_mv = majority_vote(train_data.weak_labels)
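
If you get stuck, here is one possible solution, offered as a sketch rather than the reference answer. It counts the non-abstain votes per class and picks the winner; ties and samples without any votes are resolved by a random draw among the tied classes. The `n_class` and `seed` parameters are additions for this sketch, and the result is a 1-D array with one label per sample.

```python
import numpy as np

ABSTAIN = -1

def majority_vote(weak_annotations, n_class=2, seed=42):
    """Return one hard label per sample; ties and all-abstain rows are broken randomly."""
    rng = np.random.default_rng(seed)
    labels = []
    for votes in np.asarray(weak_annotations):
        # count the votes per class, ignoring abstentions
        counts = np.bincount(votes[votes != ABSTAIN], minlength=n_class)
        # all classes that share the highest count (every class if nobody voted)
        winners = np.flatnonzero(counts == counts.max())
        labels.append(rng.choice(winners))
    return np.array(labels)

example = [
    [-1, -1, 1, -1, -1, -1, -1, -1, -1, -1],   # one vote for class 1
    [-1, 1, -1, 1, -1, -1, -1, -1, -1, 0],     # two votes for 1, one for 0
    [-1, -1, 1, -1, -1, -1, 0, -1, -1, -1],    # tie: random choice
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],  # no votes: random choice
]
labels = majority_vote(example)
print(labels)
```

The first two labels are always 1; the last two depend on the random draw, which is why fixing the seed matters for reproducibility.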

A ready-made solution for aggregating the weak labels with majority vote is already included in the Wrench framework: the `MajorityVoting` label model.

In [79]:
# initialize and fit the majority vote label model from the Wrench framework

from wrench.labelmodel import MajorityVoting

label_model = MajorityVoting()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)

In [80]:
# calculate weak labels 

soft_label_mv = label_model.predict_proba(train_data)    # soft label as probabilities across all classes
hard_label_mv = probs_to_preds(soft_label_mv)               # hard labels as the most probable classes 

In [81]:
hard_label_mv

array([1, 1, 0, ..., 1, 1, 1])

Let's look at the first 10 sentences, their weak annotations, and the weak labels obtained with majority voting. 

In [59]:
train_data.examples[:10]

[{'text': 'pls http://www10.vakinha.com.br/VaquinhaE.aspx?e=313327 help me get vip gun  cross fire al\ufeff'},
 {'text': 'if your like drones, plz subscribe to Kamal Tayara. He takes videos with  his drone that are absolutely beautiful.\ufeff'},
 {'text': 'go here to check the views :3\ufeff'},
 {'text': 'Came here to check the views, goodbye.\ufeff'},
 {'text': 'i am 2,126,492,636 viewer :D\ufeff'},
 {'text': 'https://www.facebook.com/teeLaLaLa\ufeff'},
 {'text': 'imagine if this guy put adsense on with all these views... u could pay ur  morgage\ufeff'},
 {'text': 'Follow me on Twitter @mscalifornia95\ufeff'},
 {'text': 'Can we reach 3 billion views by December 2014? \ufeff'},
 {'text': 'Follow 4 Follow                           @ VaahidMustafic Like 4 Like \ufeff'}]

In [60]:
train_data.weak_labels[:10]

[[-1, -1, 1, -1, -1, -1, -1, -1, -1, -1],
 [-1, 1, -1, 1, -1, -1, -1, -1, -1, 0],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, 1, -1, -1, -1, 0, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]

In [61]:
soft_label_mv[:10]

array([[0.        , 1.        ],
       [0.33333333, 0.66666667],
       [0.5       , 0.5       ],
       [0.5       , 0.5       ],
       [0.5       , 0.5       ],
       [0.5       , 0.5       ],
       [0.5       , 0.5       ],
       [0.5       , 0.5       ],
       [0.5       , 0.5       ],
       [0.5       , 0.5       ]])

In [28]:
hard_label_mv[:10]

array([1, 1, 0, 1, 0, 0, 0, 0, 1, 1])
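
The soft labels shown above are consistent with simply normalizing the vote counts per class and falling back to a uniform distribution when no labeling function matched. The sketch below reproduces the displayed values for samples 0, 1, 2, and 5 under that assumption; it is a guess at the behaviour, not Wrench's actual implementation, and `soft_majority_vote` is a hypothetical name.

```python
ABSTAIN = -1

def soft_majority_vote(weak_labels, n_class=2):
    """Normalized vote counts per class; uniform if no labeling function voted."""
    probs = []
    for votes in weak_labels:
        counts = [0] * n_class
        for v in votes:
            if v != ABSTAIN:
                counts[v] += 1
        total = sum(counts)
        if total == 0:
            probs.append([1.0 / n_class] * n_class)
        else:
            probs.append([c / total for c in counts])
    return probs

rows = [
    [-1, -1, 1, -1, -1, -1, -1, -1, -1, -1],   # sample 0
    [-1, 1, -1, 1, -1, -1, -1, -1, -1, 0],     # sample 1
    [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],  # sample 2
    [-1, -1, 1, -1, -1, -1, 0, -1, -1, -1],    # sample 5
]
soft = soft_majority_vote(rows)
for p in soft:
    print(p)
```

This also explains the hard labels: samples 2 to 9 all have the soft label `[0.5, 0.5]`, so their hard labels are the result of random tie-breaking and carry no information.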

### FABLE 

Fable [1] is a label model where noisy labels are inferred not only based on the labeling functions' votes, but also using the instance features. 

In [52]:
# initialize and apply the fable model
from wrench.labelmodel import Fable

label_model = Fable(kernel_function=None, num_groups=10)
_ = label_model.fit(dataset_train=train_data, dataset_valid=valid_data)

NaN values included: []


  0%|▏                                                                                                             | 2/1000 [00:18<2:36:20,  9.40s/iter]

stop





In [53]:
# calculate labels
soft_label_fable = label_model.predict_proba(train_data)
hard_label_fable = probs_to_preds(soft_label_fable)

  0%|▏                                                                                                             | 2/1000 [00:18<2:32:54,  9.19s/iter]

stop





In [54]:
soft_label_fable[:10]

array([[0.54772653, 0.45227347],
       [0.04384498, 0.95615502],
       [0.56616534, 0.43383466],
       [0.55068925, 0.44931075],
       [0.55410647, 0.44589353],
       [0.87257854, 0.12742146],
       [0.53917847, 0.46082153],
       [0.5582726 , 0.4417274 ],
       [0.51587386, 0.48412614],
       [0.54213748, 0.45786252]])

In [55]:
hard_label_fable[:10]

array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
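
Since this dataset happens to come with gold labels, we can sanity-check the two sets of hard labels before training on them. The sketch below compares the first ten hard labels printed above with the gold labels of the first ten training comments. The helper `label_accuracy` is hypothetical, and ten samples are only enough to show the mechanics, not to rank the label models.

```python
def label_accuracy(predicted, gold):
    """Fraction of positions where the weak label equals the gold label."""
    assert len(predicted) == len(gold)
    return sum(p == g for p, g in zip(predicted, gold)) / len(gold)

# gold labels of the first ten training comments (from df_train)
gold = [1, 1, 0, 0, 0, 1, 0, 1, 0, 1]

# first ten hard labels as printed in this notebook
hard_mv = [1, 1, 0, 1, 0, 0, 0, 0, 1, 1]
hard_fable = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]

acc_mv = label_accuracy(hard_mv, gold)
acc_fable = label_accuracy(hard_fable, gold)
print(f"majority vote: {acc_mv:.1f}, FABLE: {acc_fable:.1f}")
```

In a real weak supervision setting no gold training labels exist, so a check like this would be run on a small labeled development set instead.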

## Classifier training

In [33]:
batch_size = 32
test_batch_size = 32
lr = 0.01

Train a classifier with majority vote hard labels.

In [34]:
set_seed(42)

# initialize a classifier
model = EndClassifierModel(
    batch_size=batch_size, test_batch_size=test_batch_size
)

# fit it on the training data + majority vote hard labels
model.fit(
    dataset_train=train_data, 
    y_train=hard_label_mv, 
    dataset_valid=valid_data, 
    verbose=False
)

# test on the test set
model.test(dataset=test_data, metric_fn="acc")

0.908

Train a classifier with FABLE hard labels.

In [56]:
set_seed(42)

# initialize a classifier
model = EndClassifierModel(
    batch_size=batch_size, test_batch_size=test_batch_size
)

# fit it 
model.fit(
    dataset_train=train_data, 
    y_train=hard_label_fable, 
    dataset_valid=valid_data,
    verbose=False
)

# test on the test set
model.test(dataset=test_data, metric_fn="acc")

0.836

## End-2-End training with SepLL

In the following, we use a state-of-the-art method called SepLL [4] to train a classifier with weak labels. During training, the LF matches are the only training signal; predictions are later made from a latent state.

In [36]:
from wrench.classification.sepll import SepLL

set_seed(42)

bert_model_name = 'roberta-base'

#### Initialize SepLL
model = SepLL(
    batch_size=batch_size,
    test_batch_size=test_batch_size,
    backbone='MLP',
    backbone_model_name=bert_model_name,
    # 
    # SepLL specific
    add_unlabeled=False,
    class_noise=0.0,
    lf_l2_regularization=0.05,
)


model.fit(
    dataset_train=train_data,
    dataset_valid=valid_data,
    verbose=True
)

acc = model.test(test_data, 'acc')

logger.info(f'SepLL test acc: {acc}')

2023-09-28 23:35:08 - 
{
    "batch_size": 32,
    "real_batch_size": 16,
    "test_batch_size": 32,
    "n_steps": 10000,
    "grad_norm": -1,
    "use_lr_scheduler": false,
    "binary_mode": false
}
{
    "name": "Adam",
    "paras": {
        "lr": 0.001,
        "weight_decay": 0.0
    }
}
{
    "name": "MLP",
    "paras": {
        "hidden_size": 100,
        "dropout": 0.0,
        "model_name": "roberta-base"
    }
}
{
    "name": "MajorityVoting",
    "paras": {}
}



  labeled_dataset.weak_labels / labeled_dataset.weak_labels.sum(axis=1, keepdims=True)
  dataset_valid.weak_labels / dataset_valid.weak_labels.sum(axis=1, keepdims=True)


[TRAIN] SepLL:   0%|                                                                                          …

2023-09-28 23:35:38 - [INFO] early stop @ step 3000!
2023-09-28 23:35:38 - SepLL test acc: 0.896


### GPU training

In case your environment has a GPU available, it is also possible to make use of the full strength of SepLL. 

In [37]:
from wrench.classification.sepll import SepLL

set_seed(42)

batch_size=16
bert_model_name = 'roberta-base'

#### Initialize SepLL
model = SepLL(
    batch_size=batch_size,
    real_batch_size=batch_size,
    test_batch_size=test_batch_size,
    # BERT specific parameters
    backbone='BERT',
    backbone_model_name=bert_model_name,
    optimizer='Adam',
    optimizer_lr=5e-5,
    optimizer_weight_decay=0.0,
    
    # SepLL specific
    add_unlabeled=False,
    class_noise=0.0,
    lf_l2_regularization=0.5,
)


model.fit(
    dataset_train=train_data,
    dataset_valid=valid_data,
    metric='acc',
    verbose=True
)

2023-09-28 23:35:41 - 
{
    "batch_size": 16,
    "real_batch_size": 16,
    "test_batch_size": 32,
    "n_steps": 10000,
    "grad_norm": -1,
    "use_lr_scheduler": false,
    "binary_mode": false
}
{
    "name": "Adam",
    "paras": {
        "lr": 5e-05,
        "weight_decay": 0.0
    }
}
{
    "name": "BERT",
    "paras": {
        "model_name": "roberta-base",
        "max_tokens": 512,
        "fine_tune_layers": -1
    }
}
{
    "name": "MajorityVoting",
    "paras": {}
}



  labeled_dataset.weak_labels / labeled_dataset.weak_labels.sum(axis=1, keepdims=True)
  dataset_valid.weak_labels / dataset_valid.weak_labels.sum(axis=1, keepdims=True)
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[TRAIN] SepLL:   0%|                                                                                          …

2023-09-28 23:41:52 - KeyboardInterrupt! do not terminate the process in case need to save the best model


AttributeError: 'NoneType' object has no attribute 'copy'

In [None]:
acc = model.test(test_data, 'acc')

logger.info(f'SepLL test acc: {acc}')

# References

1. Zhang et al. 2023. Leveraging Instance Features for Label Aggregation in Programmatic Weak Supervision. https://arxiv.org/abs/2210.02724 
2. Zhang et al. 2021. WRENCH: A Comprehensive Benchmark for Weak Supervision. https://arxiv.org/abs/2109.11377
3. Alberto et al. 2015. TubeSpam: Comment Spam Filtering on YouTube. https://ieeexplore.ieee.org/document/7424299
4. Stephan et al. 2022. SepLL: Separating Latent Class Labels from Weak Supervision Noise. https://arxiv.org/abs/2210.13898
