In [2]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, AutoModelForSequenceClassification
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
newsgroups = fetch_20newsgroups(data_home="../data",subset='all')
X, y = newsgroups.data, newsgroups.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [4]:
X_train[0]

"From: mahan@TGV.COM (Patrick L. Mahan)\nSubject: Re: Is it just me, or is this newsgroup dead?\nOrganization: The Internet\nLines: 24\nNNTP-Posting-Host: enterpoop.mit.edu\nTo: xpert@expo.lcs.mit.edu, rlm@helen.surfcty.com\n\n#\n# I've gotten very few posts on this group in the last couple days.  (I\n# recently added it to my feed list.)  Is it just me, or is this group\n# near death?\n#\n\nSeen from the mailing list side, I'm getting about the right amount of\ntraffic.\n\nPatrick L. Mahan\n\n--- TGV Window Washer ------------------------------- Mahan@TGV.COM ---------\n\nWaking a person unnecessarily should not be considered  - Lazarus Long\na capital crime.  For a first offense, that is            From the Notebooks of\n\t\t\t\t\t\t\t  Lazarus Long\n\nPatrick L. Mahan\n\n--- TGV Window Washer ------------------------------- Mahan@TGV.COM ---------\n\nWaking a person unnecessarily should not be considered  - Lazarus Long\na capital crime.  For a first offense, that is            From

In [5]:
class NewsDataset(torch.utils.data.Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
        return {'input_ids': encoding['input_ids'].squeeze(), 'labels': torch.tensor(label,dtype=torch.long)}

In [6]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(newsgroups.target_names))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
train_dataset = NewsDataset(X_train, y_train, tokenizer)
test_dataset = NewsDataset(X_test, y_test, tokenizer)

In [8]:
# from torch.utils.data import DataLoader

# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [9]:
test_dataset[0]

{'input_ids': tensor([  101,  2013,  1024,  2911,  4842,  2102,  1030, 10507,  2080,  1012,
         10250, 15007,  1012,  3968,  2226,  1006,  5199,  2911,  4842,  2102,
          1007,  3395,  1024,  2128,  1024,  1999,  3790,  4875,  3627,  3029,
          1024,  2662,  2820,  1997,  2974,  1010, 18880,  3210,  1024,  2423,
          1050,  3372,  2361,  1011, 14739,  1011,  3677,  1024,  5472,  2386,
          1012, 10250, 15007,  1012,  3968,  2226,  3781, 22844,  4246,  1030,
          3660,  1012,  8301, 22117,  5686,  1012,  3968,  2226,  1006,  6108,
         20996,  3995,  4246,  1007,  7009,  1024,  1028,  2028,  2197,  1999,
          3790,  4875,  3160,  2008,  2038,  2467, 14909,  2033,  1998,  8440,
          1005,  1056,  1028,  2664,  2042,  8280,  1012,  1045,  2903,  1996,
          3627,  2036,  2515,  1008,  2025,  1008,  3066,  2007,  2023,  1028,
          3663,  1024,  1028,  2174,  1010,  2065,  1996,  1999,  3790,  4875,
          2003,  1008,  2025,  1008,  3

In [14]:
from networkx import davis_southern_women_graph


training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    logging_dir='./logs',
    logging_steps=10,
    save_steps=20,
    eval_strategy="steps",
    eval_steps=10,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

In [15]:
trainer.train()()

                                                   
  0%|          | 20/4715 [10:56<2:34:58,  1.98s/it]

{'loss': 2.8848, 'grad_norm': 5.042397975921631, 'learning_rate': 4.982325910215624e-05, 'epoch': 0.01}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                

  0%|          | 20/4715 [24:51<2:34:58,  1.98s/it]
[A
[A

{'eval_loss': 2.911072015762329, 'eval_runtime': 834.7725, 'eval_samples_per_second': 4.516, 'eval_steps_per_second': 0.141, 'epoch': 0.01}


                                                   
  0%|          | 20/4715 [25:26<2:34:58,  1.98s/it] 

{'loss': 2.7823, 'grad_norm': 8.89710521697998, 'learning_rate': 4.964651820431248e-05, 'epoch': 0.02}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                 

  0%|          | 20/4715 [39:30<2:34:58,  1.98s/it]
[A
[A

{'eval_loss': 2.6679115295410156, 'eval_runtime': 843.4878, 'eval_samples_per_second': 4.47, 'eval_steps_per_second': 0.14, 'epoch': 0.02}


                                                   
  0%|          | 20/4715 [40:06<2:34:58,  1.98s/it] 

{'loss': 2.5924, 'grad_norm': 6.133090019226074, 'learning_rate': 4.946977730646872e-05, 'epoch': 0.03}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                 

  0%|          | 20/4715 [53:29<2:34:58,  1.98s/it]
[A
[A

{'eval_loss': 2.447645425796509, 'eval_runtime': 802.9969, 'eval_samples_per_second': 4.695, 'eval_steps_per_second': 0.147, 'epoch': 0.03}


                                                   
  0%|          | 20/4715 [54:02<2:34:58,  1.98s/it] 

{'loss': 2.306, 'grad_norm': 6.286175727844238, 'learning_rate': 4.929303640862496e-05, 'epoch': 0.04}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                 

  0%|          | 20/4715 [1:07:22<2:34:58,  1.98s/it]
[A
[A

{'eval_loss': 2.164458751678467, 'eval_runtime': 800.4463, 'eval_samples_per_second': 4.71, 'eval_steps_per_second': 0.147, 'epoch': 0.04}


                                                     
  0%|          | 20/4715 [1:07:57<2:34:58,  1.98s/it]

{'loss': 2.1284, 'grad_norm': 9.0938138961792, 'learning_rate': 4.91162955107812e-05, 'epoch': 0.05}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                     
[A                                                 

  0%|          | 20/4715 [1:26:41<2:34:58,  1.98s/it]
[A
[A

{'eval_loss': 1.9375823736190796, 'eval_runtime': 1123.3789, 'eval_samples_per_second': 3.356, 'eval_steps_per_second': 0.105, 'epoch': 0.05}


                                                     
  0%|          | 20/4715 [1:27:14<2:34:58,  1.98s/it] 

{'loss': 1.8877, 'grad_norm': 8.512142181396484, 'learning_rate': 4.893955461293744e-05, 'epoch': 0.06}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                     
[A                                                   

  0%|          | 20/4715 [1:44:05<2:34:58,  1.98s/it]
[A
[A

{'eval_loss': 1.8302439451217651, 'eval_runtime': 1011.509, 'eval_samples_per_second': 3.727, 'eval_steps_per_second': 0.117, 'epoch': 0.06}


                                                     
  0%|          | 20/4715 [1:45:00<2:34:58,  1.98s/it] 

{'loss': 1.9236, 'grad_norm': 7.317235946655273, 'learning_rate': 4.876281371509368e-05, 'epoch': 0.07}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

KeyboardInterrupt: 

In [None]:
predictions = trainer.predict(test_dataset)
pred_labels = predictions.predictions.argmax(-1)
accuracy = accuracy_score(y_test, pred_labels)
print(f"Accuracy: {accuracy}")

# Classification report
print(classification_report(y_test, pred_labels, target_names=newsgroups.target_names))



Accuracy: 0.07214854111405836
                          precision    recall  f1-score   support

             alt.atheism       0.00      0.00      0.00       151
           comp.graphics       0.00      0.00      0.00       202
 comp.os.ms-windows.misc       0.00      0.00      0.00       195
comp.sys.ibm.pc.hardware       0.00      0.00      0.00       183
   comp.sys.mac.hardware       0.00      0.00      0.00       205
          comp.windows.x       0.00      0.00      0.00       215
            misc.forsale       0.13      0.82      0.22       193
               rec.autos       0.00      0.00      0.00       196
         rec.motorcycles       0.00      0.00      0.00       168
      rec.sport.baseball       0.00      0.00      0.00       211
        rec.sport.hockey       0.07      0.55      0.13       198
               sci.crypt       0.00      0.00      0.00       201
         sci.electronics       0.00      0.00      0.00       202
                 sci.med       0.00      0.00

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
