# Imports

In [None]:
from .semantic_search import *
from finetune_sbert import SBERTTuner
from .data.newsgroups_data import NewsGroupsDataset
from sentence_transformers import losses
from finetune_distilbert import DistilBERTTuner

# Load data

In [2]:
# Dataset configuration.
# ALL_CATEGORIES lists every newsgroup passed to the dataset.
# HOLDOUT_CLASSES names three of those groups; the recorded pipeline output
# reports them as held out and the remaining seventeen as the training classes.

ALL_CATEGORIES = [
    'alt.atheism',
    'comp.graphics',
    'comp.os.ms-windows.misc',
    'comp.sys.ibm.pc.hardware',
    'comp.sys.mac.hardware',
    'comp.windows.x',
    'misc.forsale',
    'rec.autos',
    'rec.motorcycles',
    'rec.sport.baseball',
    'rec.sport.hockey',
    'sci.crypt',
    'sci.electronics',
    'sci.med',
    'sci.space',
    'soc.religion.christian',
    'talk.politics.guns',
    'talk.politics.mideast',
    'talk.politics.misc',
    'talk.religion.misc',
]

HOLDOUT_CLASSES = [
    'talk.politics.guns',
    'talk.politics.mideast',
    'talk.politics.misc',
]

# Keyword arguments later unpacked into NewsGroupsDataset(**newsgroups_params).
newsgroups_params = dict(
    categories=ALL_CATEGORIES,
    holdout_classes=HOLDOUT_CLASSES,
)



In [3]:
# Instantiate the dataset, unpacking newsgroups_params as keyword arguments
# (its keys are 'categories' and 'holdout_classes').
dataset = NewsGroupsDataset(**newsgroups_params)

# SBERT Evaluations

In [None]:
# Create the SBERT fine-tuner for the dataset built above.
# NOTE(review): dataset_name presumably determines where the tuned model is saved
# (a later cell loads 'models/tuned_sbert/20newsgroups') — confirm in finetune_sbert.
sbert_tuner = SBERTTuner(dataset=dataset, dataset_name='20newsgroups')

In [None]:
# Hyperparameters for the SBERT fine-tuning run, listed one per line so they are
# easy to scan and adjust.
# NOTE(review): use_student_teacher / kl_alpha presumably enable a distillation-style
# objective with a KL term weighted by kl_alpha — confirm in finetune_sbert.
sbert_finetune_kwargs = dict(
    batch_size=16,
    epochs=30,
    learning_rate=1e-06,
    weight_decay=0,
    max_grad_norm=0.5,
    use_student_teacher=True,
    kl_alpha=0.85,
)
sbert_tuner.fine_tune(**sbert_finetune_kwargs)

In [6]:
# Run the evaluation pipeline with model_name='all-MiniLM-L6-v2'
# (presumably the off-the-shelf sentence-transformers checkpoint, i.e. the
# untuned baseline — confirm). The recorded output reports KNN metrics for the
# regular test split and for the few-shot holdout classes.
run_text_classification_pipeline(dataset, model_name='all-MiniLM-L6-v2')


Trained with classes: alt.atheism, comp.graphics, comp.os.ms-windows.misc, comp.sys.ibm.pc.hardware, comp.sys.mac.hardware, comp.windows.x, misc.forsale, rec.autos, rec.motorcycles, rec.sport.baseball, rec.sport.hockey, sci.crypt, sci.electronics, sci.med, sci.space, soc.religion.christian, talk.religion.misc
Holding out classes: talk.politics.guns, talk.politics.mideast, talk.politics.misc


Batches:   0%|          | 0/365 [00:00<?, ?it/s]

FAISS index created w/ dimension: 384


Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Regular Test Evaluation Results (KNN) ==
Accuracy:    0.7529
F1-Score:    0.7544
Precision:   0.7641
Recall:      0.7529
MRR:         0.8280
Top-5 Acc: 0.9088

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.62      0.70      0.65       160
           comp.graphics       0.68      0.78      0.73       195
 comp.os.ms-windows.misc       0.49      0.71      0.58       197
comp.sys.ibm.pc.hardware       0.65      0.67      0.66       196
   comp.sys.mac.hardware       0.66      0.63      0.65       193
          comp.windows.x       0.81      0.81      0.81       198
            misc.forsale       0.71      0.78      0.74       195
               rec.autos       0.77      0.74      0.75       198
         rec.motorcycles       0.88      0.78      0.82       199
      rec.sport.baseball       0.90      0.84      0.87       199
        rec.sport.hockey       0.92      0.85      0.89       200
               

Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Classifier Evaluation Metrics ==
Average Loss: 2.7823
Perplexity:   16.1558
Few-Shot Learning: Adding Holdout Classes to Index


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/183 [00:00<?, ?it/s]


== Holdout (Few-Shot) Evaluation Results (KNN) ==
Accuracy:    0.4374
F1-Score:    0.3724
Precision:   0.4584
Recall:      0.4374
MRR:         0.5260
Top-5 Acc: 0.6116

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.14      0.69      0.24       160
           comp.graphics       0.63      0.78      0.70       195
 comp.os.ms-windows.misc       0.37      0.71      0.49       197
comp.sys.ibm.pc.hardware       0.63      0.67      0.65       196
   comp.sys.mac.hardware       0.71      0.62      0.66       193
          comp.windows.x       0.77      0.81      0.79       198
            misc.forsale       0.64      0.78      0.71       195
               rec.autos       0.38      0.74      0.50       198
         rec.motorcycles       0.59      0.78      0.67       199
      rec.sport.baseball       0.75      0.84      0.79       199
        rec.sport.hockey       0.75      0.85      0.80       200
         

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [7]:
# Same pipeline, pointed at a local model directory.
# NOTE(review): presumably this is the checkpoint written by sbert_tuner.fine_tune();
# that cell has no recorded output, so confirm the save path in finetune_sbert
# and make sure fine-tuning has been run before executing this cell.
run_text_classification_pipeline(dataset, model_name='models/tuned_sbert/20newsgroups')

Trained with classes: alt.atheism, comp.graphics, comp.os.ms-windows.misc, comp.sys.ibm.pc.hardware, comp.sys.mac.hardware, comp.windows.x, misc.forsale, rec.autos, rec.motorcycles, rec.sport.baseball, rec.sport.hockey, sci.crypt, sci.electronics, sci.med, sci.space, soc.religion.christian, talk.religion.misc
Holding out classes: talk.politics.guns, talk.politics.mideast, talk.politics.misc


Batches:   0%|          | 0/365 [00:00<?, ?it/s]

FAISS index created w/ dimension: 384


Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Regular Test Evaluation Results (KNN) ==
Accuracy:    0.7643
F1-Score:    0.7664
Precision:   0.7769
Recall:      0.7643
MRR:         0.8286
Top-5 Acc: 0.8980

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.59      0.69      0.63       160
           comp.graphics       0.69      0.78      0.73       195
 comp.os.ms-windows.misc       0.49      0.75      0.59       197
comp.sys.ibm.pc.hardware       0.70      0.71      0.70       196
   comp.sys.mac.hardware       0.66      0.68      0.67       193
          comp.windows.x       0.83      0.84      0.84       198
            misc.forsale       0.79      0.83      0.81       195
               rec.autos       0.81      0.75      0.78       198
         rec.motorcycles       0.87      0.73      0.79       199
      rec.sport.baseball       0.92      0.86      0.89       199
        rec.sport.hockey       0.95      0.84      0.89       200
               

Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Classifier Evaluation Metrics ==
Average Loss: 2.7729
Perplexity:   16.0045
Few-Shot Learning: Adding Holdout Classes to Index


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/183 [00:00<?, ?it/s]


== Holdout (Few-Shot) Evaluation Results (KNN) ==
Accuracy:    0.4410
F1-Score:    0.3786
Precision:   0.6224
Recall:      0.4410
MRR:         0.5242
Top-5 Acc: 0.6020

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.12      0.69      0.20       160
           comp.graphics       0.66      0.78      0.71       195
 comp.os.ms-windows.misc       0.33      0.75      0.46       197
comp.sys.ibm.pc.hardware       0.68      0.71      0.69       196
   comp.sys.mac.hardware       0.62      0.68      0.65       193
          comp.windows.x       0.79      0.84      0.81       198
            misc.forsale       0.73      0.83      0.77       195
               rec.autos       0.42      0.75      0.54       198
         rec.motorcycles       0.59      0.73      0.65       199
      rec.sport.baseball       0.82      0.86      0.84       199
        rec.sport.hockey       0.86      0.84      0.85       200
         

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [8]:
# Same pipeline with model_name='distilbert-base-nli-stsb-mean-tokens'.
# The recorded output shows a 768-dimensional FAISS index for this model
# (the two runs above report 384).
# NOTE(review): presumably the untuned counterpart of the fine-tuned DistilBERT
# model evaluated later in the notebook — confirm.
run_text_classification_pipeline(dataset, model_name='distilbert-base-nli-stsb-mean-tokens')

Trained with classes: alt.atheism, comp.graphics, comp.os.ms-windows.misc, comp.sys.ibm.pc.hardware, comp.sys.mac.hardware, comp.windows.x, misc.forsale, rec.autos, rec.motorcycles, rec.sport.baseball, rec.sport.hockey, sci.crypt, sci.electronics, sci.med, sci.space, soc.religion.christian, talk.religion.misc
Holding out classes: talk.politics.guns, talk.politics.mideast, talk.politics.misc


Batches:   0%|          | 0/365 [00:00<?, ?it/s]

FAISS index created w/ dimension: 768


Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Regular Test Evaluation Results (KNN) ==
Accuracy:    0.5898
F1-Score:    0.5929
Precision:   0.6114
Recall:      0.5898
MRR:         0.6944
Top-5 Acc: 0.8234

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.42      0.55      0.48       160
           comp.graphics       0.39      0.58      0.47       195
 comp.os.ms-windows.misc       0.36      0.56      0.44       197
comp.sys.ibm.pc.hardware       0.43      0.42      0.42       196
   comp.sys.mac.hardware       0.43      0.46      0.45       193
          comp.windows.x       0.56      0.61      0.58       198
            misc.forsale       0.68      0.67      0.68       195
               rec.autos       0.72      0.57      0.64       198
         rec.motorcycles       0.72      0.60      0.66       199
      rec.sport.baseball       0.82      0.73      0.77       199
        rec.sport.hockey       0.83      0.76      0.79       200
               

Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Classifier Evaluation Metrics ==
Average Loss: 2.7744
Perplexity:   16.0294
Few-Shot Learning: Adding Holdout Classes to Index


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/183 [00:00<?, ?it/s]


== Holdout (Few-Shot) Evaluation Results (KNN) ==
Accuracy:    0.3264
F1-Score:    0.2494
Precision:   0.2119
Recall:      0.3264
MRR:         0.3913
Top-5 Acc: 0.4705

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.11      0.55      0.18       160
           comp.graphics       0.30      0.58      0.40       195
 comp.os.ms-windows.misc       0.25      0.56      0.34       197
comp.sys.ibm.pc.hardware       0.37      0.42      0.39       196
   comp.sys.mac.hardware       0.42      0.45      0.43       193
          comp.windows.x       0.48      0.61      0.53       198
            misc.forsale       0.53      0.67      0.59       195
               rec.autos       0.31      0.57      0.40       198
         rec.motorcycles       0.45      0.60      0.51       199
      rec.sport.baseball       0.42      0.73      0.53       199
        rec.sport.hockey       0.57      0.76      0.65       200
         

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# DistilBERT Evaluations

In [None]:
# Create the DistilBERT fine-tuner. Arguments are positional here, whereas the
# SBERTTuner call uses dataset= / dataset_name= keywords.
# NOTE(review): presumably the same two parameters in the same order — verify
# against the DistilBERTTuner signature.
DB_tuner = DistilBERTTuner(dataset, '20newsgroups')

In [None]:
# Fine-tune for 30 epochs; only `epochs` is passed, so every other setting uses
# the tuner's defaults. The recorded output logs the training loss every 500
# steps and ends with "Model saved at models/tuned_distilbert/20newsgroups".
DB_tuner.fine_tune(epochs=30)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,2.1017
1000,1.4152
1500,1.1638
2000,0.9994
2500,0.9026
3000,0.7742
3500,0.6822
4000,0.6567
4500,0.6049
5000,0.5341


Model saved at models/tuned_distilbert/20newsgroups


In [11]:
# Evaluate the fine-tuned DistilBERT model, loaded from the path reported by the
# fine-tuning cell's output ("Model saved at models/tuned_distilbert/20newsgroups").
run_text_classification_pipeline(dataset, model_name='models/tuned_distilbert/20newsgroups')

Trained with classes: alt.atheism, comp.graphics, comp.os.ms-windows.misc, comp.sys.ibm.pc.hardware, comp.sys.mac.hardware, comp.windows.x, misc.forsale, rec.autos, rec.motorcycles, rec.sport.baseball, rec.sport.hockey, sci.crypt, sci.electronics, sci.med, sci.space, soc.religion.christian, talk.religion.misc
Holding out classes: talk.politics.guns, talk.politics.mideast, talk.politics.misc


Batches:   0%|          | 0/365 [00:00<?, ?it/s]

FAISS index created w/ dimension: 768


Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Regular Test Evaluation Results (KNN) ==
Accuracy:    0.7260
F1-Score:    0.7291
Precision:   0.7354
Recall:      0.7260
MRR:         0.7714
Top-5 Acc: 0.8259

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.58      0.61      0.59       160
           comp.graphics       0.69      0.72      0.71       195
 comp.os.ms-windows.misc       0.47      0.68      0.56       197
comp.sys.ibm.pc.hardware       0.62      0.62      0.62       196
   comp.sys.mac.hardware       0.61      0.62      0.61       193
          comp.windows.x       0.82      0.84      0.83       198
            misc.forsale       0.78      0.75      0.77       195
               rec.autos       0.77      0.75      0.76       198
         rec.motorcycles       0.81      0.77      0.79       199
      rec.sport.baseball       0.93      0.86      0.90       199
        rec.sport.hockey       0.91      0.85      0.88       200
               

Batches:   0%|          | 0/102 [00:00<?, ?it/s]


== Classifier Evaluation Metrics ==
Average Loss: 2.7357
Perplexity:   15.4199
Few-Shot Learning: Adding Holdout Classes to Index


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/183 [00:00<?, ?it/s]


== Holdout (Few-Shot) Evaluation Results (KNN) ==
Accuracy:    0.4058
F1-Score:    0.3415
Precision:   0.5935
Recall:      0.4058
MRR:         0.4514
Top-5 Acc: 0.5045

Full Classification Report:
                          precision    recall  f1-score   support

             alt.atheism       0.10      0.61      0.17       160
           comp.graphics       0.62      0.72      0.67       195
 comp.os.ms-windows.misc       0.35      0.68      0.46       197
comp.sys.ibm.pc.hardware       0.60      0.62      0.61       196
   comp.sys.mac.hardware       0.64      0.61      0.63       193
          comp.windows.x       0.79      0.84      0.81       198
            misc.forsale       0.70      0.75      0.72       195
               rec.autos       0.54      0.75      0.63       198
         rec.motorcycles       0.40      0.77      0.53       199
      rec.sport.baseball       0.83      0.86      0.84       199
        rec.sport.hockey       0.78      0.85      0.81       200
         

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
