In [1]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/News_Classification

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/.shortcut-targets-by-id/1EgYdJ-mvNhzcz5FT18MSZz5pO5bmDV9d/News_Classification


In [2]:
!pip install torch
!pip install transformers



In [3]:
import os
import time
import json
import logging
from dataclasses import dataclass
from tqdm import tqdm
from functools import partial
tqdm = partial(tqdm, position=0, leave=True)


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import pipeline
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

In [4]:
checkpoint_path = os.path.join(os.getcwd(), "checkpoints")
dataset_path = os.path.join(os.getcwd(), 'dataset')
test_data_path = os.path.join(dataset_path, 'processed_test.json')

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

batch_size = 8
max_token_len = 1024
temperature = 0.1
checkpoint_name_list = os.listdir(checkpoint_path)[7: 9]  # choose 3 to 4 models

In [6]:
checkpoint_name_list

['checkpoint_3249.bin', 'checkpoint_3610.bin']

In [7]:
classifier = pipeline(
    'zero-shot-classification',
    tokenizer='facebook/bart-large-mnli',
    model='facebook/bart-large-mnli',
    device=device.index,
    framework='pt',
)

In [8]:
# Reference from: https://github.com/yinwenpeng/BenchmarkingZeroShot/blob/master/src/train_yahoo.py
hypothesis_template = "This text {}."

choice_to_hypothesis = {
    "Containment and Closure Policies": [
        'is related with containment and closure policy from governments in the pandemic',
    ],
    "Economic Policies": [
        'is related with economic policy from governments in the pandemic', 
    ],
    "Health System Policies": [
        'is related with health system policy from governments in the pandemic',
    ],
    "Miscellaneous Policies": [
        'is related with miscellaneous policy from governments in the pandemic', 
    ]
}

for value in choice_to_hypothesis.values():
    value[0] = hypothesis_template.format(value[0])

print(json.dumps(choice_to_hypothesis, indent=4))

{
    "Containment and Closure Policies": [
        "This text is related with containment and closure policy from governments in the pandemic."
    ],
    "Economic Policies": [
        "This text is related with economic policy from governments in the pandemic."
    ],
    "Health System Policies": [
        "This text is related with health system policy from governments in the pandemic."
    ],
    "Miscellaneous Policies": [
        "This text is related with miscellaneous policy from governments in the pandemic."
    ]
}


# Prepare Data

In [9]:
@dataclass
class CovidNewsData(object):
    id: str
    premise: str = None
    hypothesis: str = None
    label: str = None

In [10]:
class CovidNewsDataset(Dataset):
    """
    Reference from https://huggingface.co/transformers/_modules/transformers/pipelines/zero_shot_classification.html#ZeroShotClassificationPipeline.__call__
    """
    def __init__(
        self, 
        data_path=None, 
        *,
        tokenizer=None, 
        choice_to_hypothesis=None,
        target_text='summary', 
        max_token_len=1024, 
        device="cpu", 
        transform=None,
      ):
        """
        Args: 
            data_path (str): The full path of the dataset. Required.
            target_text (str): Either use 'summary' or 'article' as inputs of the model. Default: 'summary'.
            tokenizer: The model's tokenizer. Required.
        """

        assert data_path is not None, f"[self.__class__.__name__] Please specify a data path."
        assert tokenizer is not None, f"[self.__class__.__name__] Please give a tokenizer."
        assert isinstance(choice_to_hypothesis, dict), f"[self.__class__.__name__] Please give a dictionary for choices to hypothesis."
        assert target_text in ['summary', 'article'], f"[self.__class__.__name__] Please pick a target_text from either 'summary' or 'article."
        self.data_path = data_path
        self.target_text = target_text
        self.tokenizer = tokenizer
        self.choice_to_hypothesis = choice_to_hypothesis
        self.class_to_id = {
            "Contradiction": 0, 
            "Neutral": 1, 
            "Entailment": 2, 
        }
        self.max_token_len = max_token_len if max_token_len < tokenizer.model_max_length else tokenizer.model_max_length
        self.device = device
        self.transform = transform

        # Init data
        self.data_list = self._get_data()

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.item()

        data = self.data_list[idx]
        encoding = self.tokenizer(
            self._get_sentence_pair(data), 
            return_tensors='pt', 
            padding='max_length',
            add_special_tokens=True, 
            truncation='only_first',  # prevent from truncating the hypothesis
            max_length=self.max_token_len
        )
        encoding['input_ids'] = encoding['input_ids'].squeeze()
        encoding['attention_mask'] = encoding['attention_mask'].squeeze()

        sample = {
            'id': data.id,
            'inputs': self._ensure_tensor_on_device(**encoding), 
            'label': torch.tensor(self.class_to_id[data.label], dtype=torch.long).to(device)
        }

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def _get_sentence_pair(self, data):

        return [[data.premise, data.hypothesis]]

    def _ensure_tensor_on_device(self, **inputs):

        return {name: tensor.to(self.device) for name, tensor in inputs.items()}

    def _get_data(self):
        with open(self.data_path) as f:
            raw_data_list = json.load(f)

        labels = list(self.class_to_id.keys())
        data_list = []
        for data in raw_data_list:
            current_choice = data['choice']
            id = data['id']
            premise = data[self.target_text]

            # Entailment, 3 from the current choice
            label = labels[2]   # Entailment
            for hypothesis in self.choice_to_hypothesis[current_choice]:
                data_list.append(CovidNewsData(
                    id=id,
                    premise=premise,
                    hypothesis=hypothesis,
                    label=label,
                ))

            # Contradiction or Not Entailment, 3 from other choices seperately
            label = labels[0]   # Contradiction or Not Entailment
            for other_choice in self.choice_to_hypothesis.keys():
                if current_choice == other_choice: continue

                randIdx = torch.randperm(
                    len(self.choice_to_hypothesis[other_choice])
                )[0].item()
                hypothesis = self.choice_to_hypothesis[other_choice][randIdx]
                data_list.append(CovidNewsData(
                    id=id,
                    premise=premise,
                    hypothesis=hypothesis,
                    label=label,
                ))

        return data_list


In [11]:
test_dataset = CovidNewsDataset(
    test_data_path, 
    tokenizer=classifier.tokenizer, 
    choice_to_hypothesis=choice_to_hypothesis,
    device=device
)
test_data_loader = DataLoader(
    dataset=test_dataset, 
    batch_size=batch_size,
    shuffle=False
)

In [12]:
print(f"Length of the testing dataset: {len(test_dataset)}")
print(test_dataset[0])

Length of the testing dataset: 1744
{'id': 1342596, 'inputs': {'input_ids': tensor([    0, 31921, 12127,  ...,     1,     1,     1], device='cuda:0'), 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0], device='cuda:0')}, 'label': tensor(2, device='cuda:0')}


# Prepare for Analyzing

In [13]:
criterion = nn.KLDivLoss(reduction='sum')

# Analyzing

## Zero-shot (Uncomment to get the output distribution of ***bart-large-mnli*** given the dataset.)

In [14]:
# if 'classifier' not in globals() or 'classifier' not in locals():
#     classifier = pipeline(
#         'zero-shot-classification',
#         tokenizer='facebook/bart-large-mnli',
#         model='facebook/bart-large-mnli',
#         device=device.index,
#         framework='pt',
#     )
# classifier.model.eval()

# losses = 0.
# zero_shot_distribution = []

# with torch.no_grad():
#     for batch_idx, batch in enumerate(tqdm(test_data_loader)):
#         inputs = batch['inputs']
#         label = batch['label']
#         one_hot_label = F.one_hot(label, num_classes=3)

#         outputs = classifier.model(**inputs)
#         output_distribution = (outputs.logits/temperature).softmax(dim=-1)

#         loss = criterion(output_distribution.log(), one_hot_label)
#         losses += loss.item()

#         zero_shot_distribution.append(output_distribution)
#     losses /= len(test_dataset)

# print(' checkpoint: {} | loss (label): {:5.3f}'.format("facebook/bart-large-mnli", losses))

# torch.save(zero_shot_distribution, os.path.join(dataset_path, f"zero_shot_distribution_temp_{temperature}.pt"))

## Few-shots & fine-tuned

In [15]:
if 'zero_shot_distribution' not in locals() or 'zero_shot_distribution' not in globals():
    zero_shot_distribution = torch.load(os.path.join(dataset_path, f"zero_shot_distribution_temp_{temperature}.pt"), map_location=device)

loss_label_record_path = os.path.join(dataset_path, f"loss_label_record_temp_{temperature}.pt")
loss_zero_shot_record_path = os.path.join(dataset_path, f"loss_zero_shot_record_temp_{temperature}.pt")
loss_label_record = torch.load(loss_label_record_path) if os.path.exists(loss_label_record_path) else {}
loss_zero_shot_record = torch.load(loss_zero_shot_record_path) if os.path.exists(loss_zero_shot_record_path) else {}

for checkpoint_name in checkpoint_name_list:
    if 'classifier' in locals() or 'classifier' in globals():
        # Can't delete clearly
        print("Delete previous classifier")
        classifier.model.to('cpu')
        del classifier
        torch.cuda.empty_cache()

    # Load checkpoints
    model_path = os.path.join(checkpoint_path, checkpoint_name)
    classifier = pipeline(
        'zero-shot-classification',
        tokenizer='facebook/bart-large-mnli',
        model=model_path,
        device=device.index,
        framework='pt',
    )
    classifier.model.eval()
    print(f"Current checkpoint: {checkpoint_name}")

    losses, losses_zero_shot = 0., 0.
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_data_loader)):
            inputs = batch['inputs']
            label = batch['label']
            one_hot_label = F.one_hot(label, num_classes=3)

            outputs = classifier.model(**inputs)
            output_distribution = (outputs.logits/temperature).softmax(dim=-1)

            loss = criterion(output_distribution.log(), one_hot_label)
            losses += loss.item()

            loss = criterion(outputs.logits, zero_shot_distribution[batch_idx])
            losses_zero_shot += loss.item()
    losses /= len(test_dataset)
    losses_zero_shot /= len(test_dataset)

    loss_label_record[checkpoint_name] = losses
    loss_zero_shot_record[checkpoint_name] = losses_zero_shot

    print(' checkpoint: {} | loss (label): {:5.3f} | loss (zero-shot): {:5.3f}'.format(checkpoint_name, losses, losses_zero_shot))

Delete previous classifier


  0%|          | 0/218 [00:00<?, ?it/s]

Current checkpoint: checkpoint_3249.bin


100%|██████████| 218/218 [09:21<00:00,  2.58s/it]


 checkpoint: checkpoint_3249.bin | loss (label): 9.849 | loss (zero-shot): 2.006
Delete previous classifier


  0%|          | 0/218 [00:00<?, ?it/s]

Current checkpoint: checkpoint_3610.bin


100%|██████████| 218/218 [09:37<00:00,  2.65s/it]

 checkpoint: checkpoint_3610.bin | loss (label): 11.035 | loss (zero-shot): 1.856





In [16]:
torch.save(loss_label_record, os.path.join(dataset_path, f"loss_label_record_temp_{temperature}.pt"))
torch.save(loss_zero_shot_record, os.path.join(dataset_path, f"loss_zero_shot_record_temp_{temperature}.pt"))

In [17]:
loss_label_record

{'checkpoint_1083.bin': 3.009235018305335,
 'checkpoint_1444.bin': 4.105253546325695,
 'checkpoint_1805.bin': 4.751828841851915,
 'checkpoint_2166.bin': 6.514452876382762,
 'checkpoint_2527.bin': 8.461437925422981,
 'checkpoint_2888.bin': 8.945213681164542,
 'checkpoint_3249.bin': 9.849357470922586,
 'checkpoint_361.bin': 2.212872189939569,
 'checkpoint_3610.bin': 11.03453324741014,
 'checkpoint_722.bin': 1.3737209406467752}

In [18]:
loss_zero_shot_record

{'checkpoint_1083.bin': 1.7673880211120352,
 'checkpoint_1444.bin': 1.9427076656747302,
 'checkpoint_1805.bin': 2.0233392078395283,
 'checkpoint_2166.bin': 1.8364907042422425,
 'checkpoint_2527.bin': 2.0709450818393207,
 'checkpoint_2888.bin': 1.8245160898210806,
 'checkpoint_3249.bin': 2.006164341276392,
 'checkpoint_361.bin': 1.4395465304682014,
 'checkpoint_3610.bin': 1.8559674210162884,
 'checkpoint_722.bin': 1.6387985161411653}