## Multi Label classification

#### Topic Modelling on the Reuters Dataset. 

Binary classification is where an input can be classified in to 2 categories

Multi-class classification is where an input can be classified to any ONE of many categories

Multi-label classification is where an input can be classified to ANY NUMBER of many categories 

In this notebook we'll explore multi-label classification within the field of topic modelling. In our case, stating the topics associated with a reuters news article. 

#### Inputs
- nltk is the natural language toolkit, where the reuters dataset is stored
- the nltk corpus contains information about the reuters dataset
- torch (pytorch) libraries contain helpful tools for deep learning including the dataloader and optimiser
- the model and tokeniser are loaded from the transformers library. 
- label_ranking_average_precision_score is an evaluation tool
-  the counter class is helpful for complex counting over iterables



In [5]:
import nltk 
from nltk.corpus import reuters

import torch.utils.data
import torch.nn as nn 
import torch.optim as optim

from transformers import AutoModelForSequenceClassification , AutoTokenizer

from sklearn.metrics import label_ranking_average_precision_score

from tqdm import tqdm
from collections import Counter



Download the dataset

In [184]:
nltk.download("reuters")

[nltk_data] Downloading package reuters to /Users/Ben/nltk_data...
[nltk_data]   Package reuters is already up-to-date!


True

There are 90 categories which the documents can be labeled by

In [185]:
reuters_documents = reuters.fileids()
reuters_categories = reuters.categories()

This function returns information about the reuters dataset, including the test:train split. 

In [187]:
def dataset_info(documents):

    """Information about Reuters dataset, such as number of training and test documents, and categories"""
    train_docs = list(filter(lambda doc: doc.startswith("train"), documents))
    test_docs = list(filter(lambda doc: doc.startswith("test"), documents))
    # 10788 documents, 7769 for training and 3019 for test
    print(str(len(documents)) + " documents")
    print(str(len(train_docs)) + " total train documents")
    print(str(len(test_docs)) + " total test documents")

In [188]:
dataset_info(reuters_documents)

10788 documents
7769 total train documents
3019 total test documents


This function returns the number of documents which are labeleld by each of the 90 categories. As this dataset is multi-label the sum of these figures will be more than the number of documents

In [189]:
def documents_per_category(categories):
    """Return the number of documents per category"""
    def get_category_length_tuple(cat):
        return (len(reuters.fileids(cat)), cat)
    return [get_category_length_tuple(cat) for cat in categories]



In [190]:
documents_per_category(reuters_categories)

[(2369, 'acq'),
 (58, 'alum'),
 (51, 'barley'),
 (105, 'bop'),
 (68, 'carcass'),
 (2, 'castor-oil'),
 (73, 'cocoa'),
 (6, 'coconut'),
 (7, 'coconut-oil'),
 (139, 'coffee'),
 (65, 'copper'),
 (3, 'copra-cake'),
 (237, 'corn'),
 (59, 'cotton'),
 (3, 'cotton-oil'),
 (97, 'cpi'),
 (4, 'cpu'),
 (578, 'crude'),
 (3, 'dfl'),
 (175, 'dlr'),
 (14, 'dmk'),
 (3964, 'earn'),
 (23, 'fuel'),
 (54, 'gas'),
 (136, 'gnp'),
 (124, 'gold'),
 (582, 'grain'),
 (9, 'groundnut'),
 (2, 'groundnut-oil'),
 (19, 'heat'),
 (22, 'hog'),
 (20, 'housing'),
 (16, 'income'),
 (6, 'instal-debt'),
 (478, 'interest'),
 (53, 'ipi'),
 (54, 'iron-steel'),
 (5, 'jet'),
 (67, 'jobs'),
 (8, 'l-cattle'),
 (29, 'lead'),
 (15, 'lei'),
 (2, 'lin-oil'),
 (99, 'livestock'),
 (16, 'lumber'),
 (49, 'meal-feed'),
 (717, 'money-fx'),
 (174, 'money-supply'),
 (6, 'naphtha'),
 (105, 'nat-gas'),
 (9, 'nickel'),
 (3, 'nkr'),
 (4, 'nzdlr'),
 (14, 'oat'),
 (171, 'oilseed'),
 (27, 'orange'),
 (3, 'palladium'),
 (40, 'palm-oil'),
 (3, 'palmkern

In [191]:
def categories_per_documents(documents):
    """Reuters contains multilabeled documents.
    This method returns the number of labels and the corresponding number of documents
    e.g. 2:1173 means that there are 1173 documents with 2 categories (multilabel)
    """
    def categories_per_document(fid):
        return (len(reuters.categories(fid)), fid)


    list_of_categories_per_doc = [
    categories_per_document(doc) for doc in documents]
    # Returns the number of documents that fall in multiple categories
    return(Counter([a for (a, b) in list_of_categories_per_doc]))

In [192]:
def categories_per_document(fid):
        return (len(reuters.categories(fid)), fid)

[(reuters.categories(doc),doc) for doc in reuters_documents if categories_per_document(doc)[0] == 15]




[(['coffee',
   'copra-cake',
   'corn',
   'cotton',
   'grain',
   'palm-oil',
   'palmkernel',
   'rice',
   'rubber',
   'soy-meal',
   'soybean',
   'sugar',
   'tea',
   'veg-oil',
   'wheat'],
  'training/235')]

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

The next few cells set up  the test and train dataset and initialise the tokeniser and model. 

In [194]:
class Reuters(torch.utils.data.Dataset):
    def __init__(self,mode = 'train', tokenise=False):
        nltk.download("reuters")
        self.fileids = list(filter(lambda doc: doc.startswith(mode), reuters.fileids()))
        self.text = [reuters.raw(fid) for fid in self.fileids]
        self.category_to_index = {cat:index for (index,cat) in enumerate(reuters.categories())} 
        self.label = [[self.category_to_index[cat] for cat in reuters.categories(fid)] for fid in self.fileids]
        self.mode = mode

    def __len__(self):
        return len(self.fileids)

    def __getitem__(self,index):
        return (self.text[index], self.label[index])
   
train_dataset = Reuters("train")
test_dataset = Reuters("test")

print(len(train_dataset), len(test_dataset))
print(train_dataset[100])



[nltk_data] Downloading package reuters to /Users/Ben/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package reuters to /Users/Ben/nltk_data...
[nltk_data]   Package reuters is already up-to-date!


7769 3019
('AMERICAN STORES &lt;ASC> SEES LOWER YEAR NET\n  American Stores Co said it\n  expects to report earnings per share of 3.70 to 3.85 dlrs per\n  share on sales of slightly over 14 billion dlrs for the year\n  ended January 31.\n      The supermarket chain earned 4.11 dlrs per share on sales\n  of 13.89 billion dlrs last year.\n      The company did not elaborate.\n  \n\n', [21])


In [196]:
MODEL_NAME = 'distilbert-base-uncased'
NUMBER_OF_CLASSES = 90
BATCH_SIZE = 5
NUMBER_OF_EPOCHS = 5


In [197]:
tokeniser = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUMBER_OF_CLASSES)


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.bias', 'pre_classi

The commented out line allows us to test the evaluation method runs without running over the whole large dataset by reducing it to a single value. 

In [222]:
tokenised_test = tokeniser(test_dataset.text,truncation=True,padding=True) # these are all default args because of the tokeniser we've loaded with the model we've loaded. 
tokenised_train = tokeniser(train_dataset.text,truncation=True,padding=True)

tokenised_test_dataset = [{"labels":nn.functional.one_hot(torch.tensor(label),num_classes=NUMBER_OF_CLASSES).sum(dim=0), "input_ids": text,"attention_mask":mask} for label,text,mask in zip(test_dataset.label,tokenised_test["input_ids"],tokenised_test["attention_mask"])]
tokenised_train_dataset = [{"labels":nn.functional.one_hot(torch.tensor(label),num_classes=NUMBER_OF_CLASSES).sum(dim=0), "input_ids": text,"attention_mask":mask} for label,text,mask in zip(train_dataset.label,tokenised_train["input_ids"],tokenised_train["attention_mask"])]

# tokenised_train_dataset = tokenised_train_dataset[:1]

In [200]:
train_dataloader = torch.utils.data.DataLoader(tokenised_train_dataset, batch_size=BATCH_SIZE, shuffle = True)

In [201]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=model.to(DEVICE)

In [217]:

def create_position_weights(tokenised_train_dataset):
    labels = torch.stack([x["labels"]   for x in tokenised_train_dataset ])
    total_positive_samples = labels.sum().item()
    total_negative_samples = torch.numel(labels)-total_positive_samples

    # print(total_positive_samples,total_negative_samples)
    positive_per_label = labels.sum(dim=0)
    position_weights = total_negative_samples/positive_per_label
    return position_weights

In [218]:
params_to_update = model.parameters()
optimiser = optim.AdamW(params_to_update)
position_weights = create_position_weights(tokenised_train_dataset)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=position_weights)


In [50]:
running_loss = 0
running_correct = 0

for epoch in tqdm(range(NUMBER_OF_EPOCHS)):
    for batch in train_dataloader: 
        optimiser.zero_grad()
        input_ids = torch.stack(batch["input_ids"], 1)
        attention_mask = torch.stack(batch["attention_mask"],1)
        labels = batch["labels"].to(DEVICE)
        outputs = model(input_ids,attention_mask= attention_mask)
        loss = loss_fn(outputs["logits"], labels.float())
        loss.backward()
        optimiser.step()

        running_loss += loss.item() * input_ids.size(0)
        preds = (torch.sigmoid(outputs["logits"])>0.5).int()
        running_correct += torch.sum(preds==labels)
    epoch_loss = running_loss/len(train_dataloader.dataset)
    epoch_accuracy = running_correct.double()/(len(train_dataloader.dataset)*NUMBER_OF_CLASSES) 
    print("Epoch loss: ", epoch_loss)
    print("Epoch Accuracy: ", epoch_accuracy.item())


        

 20%|██        | 1/5 [00:02<00:11,  2.88s/it]

Epoch loss:  0.018846383318305016
Epoch Accuracy:  0.9888888888888889


 40%|████      | 2/5 [00:05<00:08,  2.95s/it]

Epoch loss:  0.03621345944702625
Epoch Accuracy:  1.9777777777777779


 60%|██████    | 3/5 [00:08<00:05,  2.97s/it]

Epoch loss:  0.05208074487745762
Epoch Accuracy:  2.966666666666667


 80%|████████  | 4/5 [00:11<00:02,  2.95s/it]

Epoch loss:  0.06613552011549473
Epoch Accuracy:  3.9555555555555557


100%|██████████| 5/5 [00:14<00:00,  2.94s/it]

Epoch loss:  0.08071557525545359
Epoch Accuracy:  4.944444444444445





In [206]:
MODEL_PATH = f"models/reuters_{MODEL_NAME}.pth"
model.load_state_dict(torch.load(MODEL_PATH,map_location=torch.device('cpu')))

<All keys matched successfully>

In [266]:
test_idx = 2
test_sample = tokenised_test_dataset[test_idx]
test_input = torch.tensor(test_sample["input_ids"],device = DEVICE).unsqueeze(0)
test_mask = torch.tensor(test_sample["attention_mask"],device = DEVICE).unsqueeze(0)
test_label = test_sample["labels"]

In [294]:
test_input

tokeniser.convert_tokens_to_string(tokeniser.convert_ids_to_tokens([103,  2900,  2000,  7065,  5562,  2146,  1011,  2744,  2943,  5157,
         28457,  1996,  3757,  1997,  2248,  3119,  1998,  3068,  1006, 10210,
          2072,  1007,  2097,  7065,  5562,  2049,  2146,  1011,  2744,  2943,
          4425,  1013,  5157, 17680,  2011,  2257,  2000,  3113,  1037, 19939,
          2091,  7913,  4859,  1999,  2887,  2943,  5157,  1010,  3757,  4584,
          2056,  1012, 10210,  2072,  2003,  3517,  2000,  2896,  1996, 13996,
          2005,  3078,  2943,  6067,  1999,  1996,  2095,102]))

'[MASK] japan to revise long - term energy demand downwards the ministry of international trade and industry ( miti ) will revise its long - term energy supply / demand outlook by august to meet a forecast downtrend in japanese energy demand, ministry officials said. miti is expected to lower the projection for primary energy supplies in the year [SEP]'

In [268]:
# test_sample
# test_dataset[2]

("JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWARDS\n  The Ministry of International Trade and\n  Industry (MITI) will revise its long-term energy supply/demand\n  outlook by August to meet a forecast downtrend in Japanese\n  energy demand, ministry officials said.\n      MITI is expected to lower the projection for primary energy\n  supplies in the year 2000 to 550 mln kilolitres (kl) from 600\n  mln, they said.\n      The decision follows the emergence of structural changes in\n  Japanese industry following the rise in the value of the yen\n  and a decline in domestic electric power demand.\n      MITI is planning to work out a revised energy supply/demand\n  outlook through deliberations of committee meetings of the\n  Agency of Natural Resources and Energy, the officials said.\n      They said MITI will also review the breakdown of energy\n  supply sources, including oil, nuclear, coal and natural gas.\n      Nuclear energy provided the bulk of Japan's electric power\n  in the fisc

In [153]:
test_train_idx = 2350
test_train_sample = tokenised_train_dataset[test_train_idx]
test_train_input = torch.tensor(test_train_sample["input_ids"],device = DEVICE).unsqueeze(0)
test_train_mask = torch.tensor(test_train_sample["attention_mask"],device = DEVICE).unsqueeze(0)

In [154]:
train_dataset.text[2350]

'BUFFTON CORP &lt;BUFF> BUYS B AND D INSTRUMENTS\n  Buffton Corp said it completed\n  the purchase of B and D Industruments Inc for two mln dlrs cash\n  and 400,000 shares of common stock.\n      It said B and D is a private company headquartered in\n  Kansas, and had sales of 4,700,000 dlrs in 1986.\n      Buffton said the company designs and manufactures aviation\n  computer display systems and engine instrumentation.\n  \n\n'

In [None]:
outputs = model(test_input,attention_mask= test_mask)["logits"]
# loss_fn(outputs, labels.float())

print("score",label_ranking_average_precision_score(test_label.unsqueeze(0), outputs.detach()))

preds = (torch.sigmoid(outputs)>0.5).int()
# preds = torch.sigmoid(outputs)
print(preds,test_label)


In [243]:
test_train_mask
test_sample
test_label.shape

torch.Size([90])

In [254]:
test_idx = 10
test_samples = tokenised_test_dataset[:test_idx]
test_inputs = torch.stack([torch.tensor(test_sample["input_ids"],device = DEVICE) for test_sample in test_samples])
test_masks = torch.stack([torch.tensor(test_sample["attention_mask"],device = DEVICE) for test_sample in test_samples])
test_labels = torch.stack([test_sample["labels"] for test_sample in test_samples])

In [265]:
test_input

tensor([[  101, 13675,  2050,  2853, 16319,  2751,  2005,  6146, 19875,  2078,
         21469,  2869,  1011,  1059, 14341,  3636,  1004,  8318,  1025,  1059,
         14341,  3636, 10495, 17953,  1028,  2056,  1996, 12360,  2009,  2003,
          2877,  2097,  3477,  6146,  1012,  4583, 19875,  2078, 21469,  2869,
          2005,  1996,  7654,  1997, 13675,  2050,  5183,  1005,  1055,  1004,
          8318,  1025, 13675, 11057,  1012,  1055,  1028,  1004,  8318,  1025,
         16319,  2751, 13866,  2100,  5183,  1028,  3131,  1010,  2988,  7483,
          1012, 13675,  2050,  1998,  1059, 14341,  3636,  2106,  2025, 26056,
          1996,  3976,  7483,  1012,  1059, 14341,  3636,  2097,  2907,  4008,
          7473,  2102,  1997,  1996, 12360,  1010,  2096,  1004,  8318,  1025,
         17151,  2102,  2860, 14341,  4219, 17953,  1028,  2097,  2907,  2676,
          7473,  2102,  1998,  1004,  8318,  1025, 13675, 22504,  2271,  5471,
         17953,  1028,  2756,  7473,  2102,  1010,  

In [261]:
outputs = model(test_inputs,attention_mask= test_masks)["logits"]
# loss_fn(outputs, labels.float())

print("score",label_ranking_average_precision_score(test_labels, outputs.detach()))

# preds = (torch.sigmoid(outputs)>0.5).int()
# # preds = torch.sigmoid(outputs)
# print(preds,test_labels)


score 0.16754075701311438
