In [1]:
import pandas as pd
import numpy as np
import os

import sys
sys.path.append('/home/ec2-user/SageMaker/mariano/repositories/tdmstudio-high-recall-information-retrieval-system/')

In [2]:
from utils.newsgroup20.dataset import Dataset20NG,DataItem20NG
from utils.newsgroup20.scal import SCAL20NG
from sklearn.metrics import f1_score, precision_score, recall_score
# unlabeled = Dataset20NG.get_20newsgroup_unlabeled_collection()

In [3]:
# PARAMS
# Each experiment configuration uses three knobs, passed to SCAL20NG as:
#   *_ni -> batch_size_cap
#   *_Ni -> random_sample_size
#   *_tg -> target_recall

# Two-step run, first round.
first_round_ni, first_round_Ni, first_round_tg = 5, 2500, 0.9

# Two-step run, second round.
second_round_ni, second_round_Ni, second_round_tg = 5, 2500, 0.8

# One-step (single round) baseline.
single_round_ni, single_round_Ni, single_round_tg = 5, 2500, 0.8

# Two-step SCAL

In [4]:



# VECTOR REPRESENTATIONS AND GROUND-TRUTH (ORACLE)
# Bag-of-words vectors for every document; change type_ to try another representation.
representations = Dataset20NG.get_20newsgroup_representations(type_="bow") # CHANGE <<<<<<<<<
# The oracle is indexed by item id and compared against DataItem20NG.RELEVANT_LABEL below.
oracle = Dataset20NG.get_20newsgroup_oracle(category='rec.motorcycles')

# UNLABELED
# Load the collection once and keep a handle on it, rather than fetching it
# again just for the sanity check at the end of this block.
full_collection = Dataset20NG.get_20newsgroup_unlabeled_collection()
total_instance_count = len(full_collection)

# LABELED
# Seed the process with a single relevant document, drawn with a fixed RNG seed.
relevants = [item for item in full_collection if oracle[item.id_] == DataItem20NG.RELEVANT_LABEL]
assert relevants, "the oracle marks no document as relevant; cannot draw a seed"
rng = np.random.default_rng(2022)
labeled = list(rng.choice(relevants, size=1))
for item in labeled:
    item.set_relevant()
labeled_ids = {item.id_ for item in labeled}

# REMOVING NEWLY LABELED FROM UNLABELED
unlabeled = [item for item in full_collection if item.id_ not in labeled_ids]

# Exactly the seed document(s) must have been removed.
assert len(unlabeled) == total_instance_count - len(labeled)

# FIRST ROUND: SCAL over the whole collection, starting from the single seed.
first_round_model = SCAL20NG(session_name='two round scal',
                             labeled_collection=labeled,
                             unlabeled_collection=unlabeled,
                             batch_size_cap=first_round_ni,
                             random_sample_size=first_round_Ni,
                             target_recall=first_round_tg,
                             ranking_function='relevance',
                             item_representation=representations,
                             oracle=oracle,
                             model_type='logreg',
                             seed=123456)

# Kept under its own name so a later `results = ...` assignment does not discard it.
first_round_results = first_round_model.run()
new_labeled = first_round_model.labeled_collection
labeled_ids = {item.id_ for item in new_labeled}

# Check the final model is usable *before* scoring documents with it.
assert first_round_model.models[-1].trained

# Documents scored above the session threshold that were not labeled during the
# round become the collection for the second round.
yhat = first_round_model.models[-1].predict(unlabeled, item_representation=representations)
suggestions = [item for item, score in zip(unlabeled, yhat)
               if score > first_round_model.threshold and item.id_ not in labeled_ids]
new_unlabeled = list(suggestions)

effort = first_round_model._total_effort()
print(f'effort={effort}')

print(f'new labeled={len(new_labeled)}')
print(f'new unlabeled={len(new_unlabeled)}')



# SECOND ROUND: SCAL restricted to the first round's suggestions, initialised
# with every document labeled in the first round.
scal_model = SCAL20NG(session_name='two round scal(b)',
                      labeled_collection=new_labeled,
                      unlabeled_collection=new_unlabeled,
                      batch_size_cap=second_round_ni,
                      random_sample_size=second_round_Ni,
                      target_recall=second_round_tg,
                      ranking_function='relevance',
                      item_representation=representations,
                      oracle=oracle,
                      model_type='logreg',
                      seed=123456)

results = scal_model.run()

# Report the second round's effort as well; otherwise only the first round's
# cost of the two-step procedure is ever shown.
# NOTE(review): confirm whether this already counts the documents the session was
# initialised with (new_labeled) before adding it to the first-round effort.
second_round_effort = scal_model._total_effort()
print(f'second round effort={second_round_effort}')


# EVALUATION of the two-step procedure.
# Ids of every document in the second-round session's labeled collection.
labeled_ids = {item.id_ for item in scal_model.labeled_collection}

assert scal_model.models[-1].trained

# NOTE(review): the second-round session was run on `new_unlabeled` (the first
# round's suggestions), but its final model and threshold are applied here to
# the whole `unlabeled` collection. Documents the first round filtered out can
# therefore re-enter the final suggestions. Confirm this is intended; otherwise
# score `new_unlabeled` instead.
yhat = scal_model.models[-1].predict(unlabeled, item_representation=representations)
final_suggestions = [item for item, score in zip(unlabeled, yhat)
                     if score > scal_model.threshold and item.id_ not in labeled_ids]
# A set keeps the per-document membership test below O(1).
final_suggestions_ids = {item.id_ for item in final_suggestions}

# Evaluate over every document that was never labeled in the session.
final_unlabeled = [item for item in Dataset20NG.get_20newsgroup_unlabeled_collection()
                   if item.id_ not in labeled_ids]
print(f'final_unlabeled={len(final_unlabeled)}')
print(f'final_labeled=  {len(scal_model.labeled_collection)}')

# Use the same relevance constant as the seed selection, not a hard-coded 'R'.
ytrue = [oracle[elem.id_] == DataItem20NG.RELEVANT_LABEL for elem in final_unlabeled]
ypred = [elem.id_ in final_suggestions_ids for elem in final_unlabeled]

print(f'Precision= {precision_score(ytrue,ypred):4.3f}')
print(f'Recall=    {recall_score(ytrue,ypred):4.3f}')
print(f'F1-score=  {f1_score(ytrue,ypred):4.3f}')

representations file found, loading pickle (/home/ec2-user/SageMaker/mariano/datasets/20news-18828/representations/20NG_representations_bow.pickle) ... 
j= 1 - B=    2 - b= 1 - len(labeled)=     2 - len(unlabeled)=  2499 - precision=0.000 - Rhat=  0.00 - tj=0.0118
j= 2 - B=    3 - b= 2 - len(labeled)=     4 - len(unlabeled)=  2497 - precision=1.000 - Rhat=  2.00 - tj=0.0111
j= 3 - B=    4 - b= 3 - len(labeled)=     7 - len(unlabeled)=  2494 - precision=1.000 - Rhat=  5.00 - tj=0.0317
j= 4 - B=    5 - b= 4 - len(labeled)=    11 - len(unlabeled)=  2490 - precision=1.000 - Rhat=  9.00 - tj=0.0537
j= 5 - B=    6 - b= 5 - len(labeled)=    16 - len(unlabeled)=  2485 - precision=1.000 - Rhat= 14.00 - tj=0.0713
j= 6 - B=    7 - b= 5 - len(labeled)=    21 - len(unlabeled)=  2479 - precision=1.000 - Rhat= 20.00 - tj=00.089
j= 7 - B=    8 - b= 5 - len(labeled)=    26 - len(unlabeled)=  2472 - precision=1.000 - Rhat= 27.00 - tj=00.109
j= 8 - B=    9 - b= 5 - len(labeled)=    31 - len(unlabeled)=  

j=30 - B=  104 - b= 5 - len(labeled)=   331 - len(unlabeled)=  1614 - precision=0.000 - Rhat=232.80 - tj=00.168
j=31 - B=  115 - b= 5 - len(labeled)=   336 - len(unlabeled)=  1510 - precision=0.000 - Rhat=232.80 - tj=00.168
j=32 - B=  127 - b= 5 - len(labeled)=   341 - len(unlabeled)=  1395 - precision=0.000 - Rhat=232.80 - tj=00.166
j=33 - B=  140 - b= 5 - len(labeled)=   346 - len(unlabeled)=  1268 - precision=0.000 - Rhat=232.80 - tj=00.159
j=34 - B=  154 - b= 5 - len(labeled)=   351 - len(unlabeled)=  1128 - precision=0.000 - Rhat=232.80 - tj=00.153
j=35 - B=  170 - b= 5 - len(labeled)=   356 - len(unlabeled)=   974 - precision=0.000 - Rhat=232.80 - tj=00.146
j=36 - B=  187 - b= 5 - len(labeled)=   361 - len(unlabeled)=   804 - precision=0.000 - Rhat=232.80 - tj=00.149
j=37 - B=  206 - b= 5 - len(labeled)=   366 - len(unlabeled)=   617 - precision=0.000 - Rhat=232.80 - tj=00.142
j=38 - B=  227 - b= 5 - len(labeled)=   371 - len(unlabeled)=   411 - precision=0.000 - Rhat=232.80 - tj

In [None]:
# `results` holds whatever the last `results = ....run()` assignment produced.
# In the two-step cell, that is the second-round session.
results

In [None]:
# The first round's suggestions may be a large list. Show the count and a short
# preview instead of dumping every item.
len(new_unlabeled), new_unlabeled[:5]

# One-step SCAL

In [49]:
# Reload the inputs so this cell does not reuse the variables of the two-step cell.
representations = Dataset20NG.get_20newsgroup_representations(type_="bow") # CHANGE <<<<<<<<<


unlabeled = Dataset20NG.get_20newsgroup_unlabeled_collection()
oracle = Dataset20NG.get_20newsgroup_oracle(category='rec.motorcycles')

relevants = [item for item in unlabeled if oracle[item.id_] == DataItem20NG.RELEVANT_LABEL]
assert relevants, "the oracle marks no document as relevant; cannot draw a seed"

# Same RNG seed as the two-step experiment, so both runs should start from the
# same seed document (assuming the collection loads in the same order).
rng = np.random.default_rng(2022)
labeled = list(rng.choice(relevants, size=1))
for item in labeled:
    item.set_relevant()

labeled_ids = {item.id_ for item in labeled}
unlabeled = [item for item in unlabeled if item.id_ not in labeled_ids]

print(f'len(labeled_ids)={len(labeled_ids)}')
print(f'len(unlabeled)={len(unlabeled)}')


# Single-round baseline. It gets its own session name: reusing 'two round scal'
# would make it indistinguishable from the first round of the two-step run.
scal_model = SCAL20NG(session_name='single round scal',
                      labeled_collection=labeled,
                      unlabeled_collection=unlabeled,
                      batch_size_cap=single_round_ni,
                      random_sample_size=single_round_Ni,
                      target_recall=single_round_tg,
                      ranking_function='relevance',
                      item_representation=representations,
                      oracle=oracle,
                      model_type='logreg',
                      seed=123456)


results = scal_model.run()

# EVALUATION of the single-round baseline.
labeled_ids = {item.id_ for item in scal_model.labeled_collection}

assert scal_model.models[-1].trained

# Unlabeled documents scored above the session threshold are the suggestions.
yhat = scal_model.models[-1].predict(unlabeled, item_representation=representations)
final_suggestions = [item for item, score in zip(unlabeled, yhat)
                     if score > scal_model.threshold and item.id_ not in labeled_ids]
# A set keeps the per-document membership test below O(1).
final_suggestions_ids = {item.id_ for item in final_suggestions}

# Evaluate over every document that was never labeled in the session.
final_unlabeled = [item for item in Dataset20NG.get_20newsgroup_unlabeled_collection()
                   if item.id_ not in labeled_ids]
print(f'final_unlabeled={len(final_unlabeled)}')
print(f'final_labeled=  {len(scal_model.labeled_collection)}')

# Use the same relevance constant as the seed selection, not a hard-coded 'R'.
ytrue = [oracle[elem.id_] == DataItem20NG.RELEVANT_LABEL for elem in final_unlabeled]
ypred = [elem.id_ in final_suggestions_ids for elem in final_unlabeled]

print(f'Precision= {precision_score(ytrue,ypred):4.3f}')
print(f'Recall=    {recall_score(ytrue,ypred):4.3f}')
print(f'F1-score=  {f1_score(ytrue,ypred):4.3f}')

representations file found, loading pickle (/home/ec2-user/SageMaker/mariano/datasets/20news-18828/representations/20NG_representations_bow.pickle) ... 
len(labeled_ids)=1
len(unlabeled)=18827
j= 1 - B=    2 - b= 1 - len(labeled)=     2 - len(unlabeled)=  2499 - precision=0.000 - Rhat=  0.00 - tj=0.0118
j= 2 - B=    3 - b= 2 - len(labeled)=     4 - len(unlabeled)=  2497 - precision=1.000 - Rhat=  2.00 - tj=0.0111
j= 3 - B=    4 - b= 3 - len(labeled)=     7 - len(unlabeled)=  2494 - precision=1.000 - Rhat=  5.00 - tj=0.0317
j= 4 - B=    5 - b= 4 - len(labeled)=    11 - len(unlabeled)=  2490 - precision=1.000 - Rhat=  9.00 - tj=0.0537
j= 5 - B=    6 - b= 5 - len(labeled)=    16 - len(unlabeled)=  2485 - precision=1.000 - Rhat= 14.00 - tj=0.0713
j= 6 - B=    7 - b= 5 - len(labeled)=    21 - len(unlabeled)=  2479 - precision=1.000 - Rhat= 20.00 - tj=00.089
j= 7 - B=    8 - b= 5 - len(labeled)=    26 - len(unlabeled)=  2472 - precision=1.000 - Rhat= 27.00 - tj=00.109
j= 8 - B=    9 - b= 5 -