In [None]:
#default_exp bart

# BART Zero-Shot Prediction Classification 

> Using modern zero-shot classification techniques

The starter code is helpful, but Hugging Face has built-in tools for zero-shot classification. These also make it easier to test different models; some may have been trained on more scientific text, which should help here.

In [None]:
#export
from ought.starter import *
from transformers import pipeline

Here, we'll use Facebook's [BART](https://huggingface.co/facebook/bart-large-mnli) model. This checkpoint is BART fine-tuned on the MultiNLI (MNLI) dataset; the zero-shot pipeline uses it as an entailment model (among other possible uses), and it should work well out of the box. The same prompts as for GPT-2 are used for consistency, though there is potentially some scope for tuning here.
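As a rough illustration of what the pipeline does with an NLI model: each candidate label is turned into a hypothesis sentence, the model scores how strongly the text entails each hypothesis, and the entailment logits are softmaxed across labels. The sketch below mimics that flow with made-up logits. `build_hypotheses` and `softmax` are illustrative helpers rather than part of `transformers`, and the template string is assumed to match the pipeline's default.

```python
import math

def build_hypotheses(labels, template="This example is {}."):
    """Turn each candidate label into an NLI hypothesis sentence."""
    return [template.format(label) for label in labels]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

labels = ["AI", "Not AI"]
hypotheses = build_hypotheses(labels)

# Made-up entailment logits, one per hypothesis.
# A real run would obtain these from the NLI model.
entailment_logits = [0.3, 1.9]
scores = softmax(entailment_logits)

# Rank labels from most to least likely
ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
print(hypotheses)
print(ranked)
```

Because the scores are normalised across the candidate labels, the wording of the labels themselves (and of the template) can shift the result, which is one place where tuning might help.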

In [None]:
# device=0 places the model on the first GPU
clas = pipeline("zero-shot-classification", device=0)
labels = ["AI", "Not AI"]

In [None]:
samples = load_jsonl("data/train.jsonl")
prompt = make_prompt('Label each of the following examples as "AI" or "NOT AI"', samples[:5], samples[5])
pred = clas(prompt, labels)
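The call returns a dictionary with `sequence`, `labels`, and `scores` keys, where the labels and scores are sorted from highest to lowest score. A small sketch of reading that structure follows, using a hand-written stand-in for `pred` (the numbers are invented) and a hypothetical helper, `top_label`:

```python
def top_label(result):
    """Return the (label, score) pair with the highest score.

    Pairs each label with its score explicitly, so it does not
    depend on the ordering of the two lists.
    """
    pairs = zip(result["labels"], result["scores"])
    return max(pairs, key=lambda pair: pair[1])

# Stand-in for the dict returned by the pipeline; values are invented.
fake_pred = {
    "sequence": "some abstract text",
    "labels": ["Not AI", "AI"],
    "scores": [0.87, 0.13],
}

label, score = top_label(fake_pred)
print(label, score)
```

Since the pipeline already sorts its output, `result["labels"][0]` gives the same answer; the helper just makes the label-to-score pairing explicit.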

## Refactor into a Single Class

We can refactor all this and export it as a single class with two useful methods:

- An initializer that loads the pretrained pipeline and draws a new set of context examples for *every* new instance. No weights are updated; the training file only supplies in-prompt examples. This is intended, since we do not know the training set ahead of time. One potential improvement here would be to fine-tune on every new `.jsonl` file that comes in and save the weights, but there is not enough data for that here.
- A `predict` method that takes in a sentence and returns a prediction by querying the pipeline.

In [None]:
#export
class BARTClassifier:
    def __init__(self, instructions='Label each of the following examples as "AI" or "NOT AI"', json='data/train.jsonl', samples=2):
        self.instructions = instructions
        self.context = uniform_samples(json, samples)
        self.clas = pipeline("zero-shot-classification", device=0)
        self.labels = ["AI", "Not AI"]
        
    def predict(self, prompt):
        prompt = make_prompt(self.instructions, self.context, {'text': prompt})
        # the pipeline returns labels sorted by descending score, so the first label is the prediction
        result = self.clas(prompt, self.labels)
        pred = result['labels'][0]
        return pred

> Note: you might have to restart the notebook to clear GPU memory at this point
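If restarting is inconvenient, dropping every reference to the pipeline and asking PyTorch to release its cached memory sometimes works instead. A minimal sketch, assuming PyTorch is the backend; `free_gpu_memory` is a hypothetical helper, and it only helps once all references to the model (for example `clas`) have been deleted:

```python
import gc

def free_gpu_memory():
    """Run garbage collection and release PyTorch's cached CUDA memory.

    Returns True if the CUDA cache was cleared, False if PyTorch
    or CUDA is unavailable.
    """
    gc.collect()
    try:
        import torch
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    torch.cuda.empty_cache()
    return True

# Typical use in the notebook: run `del clas` first, then call the helper.
cleared = free_gpu_memory()
print("CUDA cache cleared:", cleared)
```

This is not guaranteed to free everything, so a restart remains the fallback.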

In [None]:
test = load_jsonl("data/test_no_labels.jsonl")
example = test[0]
prompt = example['text']
prompt

'out of plane effect on the superconductivity of sr2 xbaxcuo3+d with tc up to 98k. we comment on the paper published by w.b. gao q.q. liu l.x. yang y.yu f.y. li c.q. jin and s. uchida in phys. rev. b and give alternate explanations for the enhanced superconductivity. the enhanced onset tc of 98k observed upon substituting ba for sr is attributed to optimal oxygen ordering rather than to the increase in volume. comparison with la2cuo +x samples suggest that the effect of disorder is overestimated.'

In [None]:
%%time
clas = BARTClassifier()
pred = clas.predict(prompt)

Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartModel: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification m

CPU times: user 22.9 s, sys: 1.16 s, total: 24.1 s
Wall time: 24.5 s


In [None]:
pred

'Not AI'