# Tutorial: Zero-Shot Learning For Substance Use Text Analysis

In this tutorial we implement zero-shot learning (ZSL) using the BART (Bidirectional and Auto-Regressive Transformers) model. This notebook is available in two forms:

1. [Online (Google Colab)](https://colab.research.google.com/github/ltu-capr/zsl-text-tutorial/blob/master/ZSL_for_substance_use_text_analysis.ipynb): For experimenting on Google's free platform without installing anything on your computer.
2. [Offline (Jupyter Notebook)](https://github.com/ltu-capr/zsl-text-tutorial): For experimenting locally on your own computer. This takes some additional setup, but is the best option for working with sensitive data.

To run the code in a cell, click inside it and then press Ctrl + Enter.

*The ZSL model at the core of this notebook runs much faster with graphics processing unit (GPU) acceleration. If you are in Google Colab, you can enable GPU acceleration in the settings by going to Runtime > Change runtime type > Hardware accelerator (select "GPU").*

## Example scenario: cannabis legalisation support

For this task, we classify cannabis-related social media posts using two labels: pro-legalisation and anti-legalisation.

### Cell 0: Install software package requirements

- Pandas is used to load and save data in CSV (comma-separated values) format.
- PyTorch and Transformers are used to run the model.
- tqdm is used to show progress bars.

In [None]:
!pip install pandas torch transformers tqdm
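
After installation, it can help to confirm that each package can be found before moving on. The following is a minimal sketch that uses only the Python standard library. The helper name `check_packages` is our own and is not part of the notebook.

```python
import importlib.util


def check_packages(names):
    """Return a dict mapping each module name to True if Python can find it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# The four packages installed in Cell 0.
status = check_packages(['pandas', 'torch', 'transformers', 'tqdm'])
for name, found in status.items():
    print(f'{name}: {"found" if found else "MISSING"}')
```

If any package is reported as missing, re-run the install cell (and restart the runtime if you are on Google Colab) before continuing.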

### Cell 1: Import essential modules

In [None]:
import os
import pandas as pd
import torch
from tqdm.notebook import tqdm
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

### Cell 2: Load the "cannabis legalisation" dataset

This sample data was generated using ChatGPT, but the methodology presented here works just as well with real data.

In [None]:
# Load the CSV file containing our dataset.
# Here we are giving the URL for a sample file that we've made publicly
# available on the Internet.
data_location = 'https://raw.githubusercontent.com/ltu-capr/zsl-text-tutorial/master/Data/cannabis_legalisation.csv'
dataframe = pd.read_csv(data_location)
input_dataset = KeyDataset(dataframe.to_dict('records'), key='text')

# Display the first 5 text examples from the dataset.
for n, text in zip(range(1, 6), input_dataset):
    print(f'{n}. {text}')
print('   ...')

### Cell 3: Initialise the BART model

Load the BART model and wrap it in a zero-shot classification pipeline. The first run downloads the model weights, so it may take a while.

In [None]:
# Check to see whether GPU acceleration is available.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Initialise the BART model.
model_type = 'facebook/bart-large-mnli'
classifier = pipeline('zero-shot-classification', model=model_type, device=device)

### Cell 4: Initialise classification labels and make model predictions

In order to perform classification we must nominate candidate labels for the model to choose between. In this scenario we have two labels, but you can choose as many labels as you need.
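
To give some intuition for why the wording of a label matters, here is a rough sketch of what happens to the candidate labels. It assumes the pipeline's default hypothesis template, `'This example is {}.'`, and the default single-label mode, in which the entailment score for each hypothesis is normalised across labels with a softmax. The logits below are made up for illustration and do not come from the model.

```python
import math

# Assumed default template of the Transformers zero-shot pipeline; it can be
# overridden with the `hypothesis_template` argument.
HYPOTHESIS_TEMPLATE = 'This example is {}.'


def build_hypotheses(candidate_labels, template=HYPOTHESIS_TEMPLATE):
    """Turn each candidate label into a sentence the model tests for entailment."""
    return [template.format(label) for label in candidate_labels]


def softmax(values):
    """Normalise a list of logits into scores that sum to 1."""
    highest = max(values)
    exps = [math.exp(value - highest) for value in values]
    total = sum(exps)
    return [value / total for value in exps]


labels = ['pro-legalisation', 'anti-legalisation']
hypotheses = build_hypotheses(labels)

# Hypothetical entailment logits, one per hypothesis (not real model output).
entailment_logits = [2.0, 0.0]
scores = softmax(entailment_logits)

for hypothesis, score in zip(hypotheses, scores):
    print(f'{hypothesis!r}: {score:.2%}')
```

Because each label is slotted into a sentence, labels that read naturally in that sentence tend to be easier for the model to judge.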

In [None]:
# Here we specify the label options that the model will choose from.
candidate_labels = ['pro-legalisation', 'anti-legalisation']

# Start the classification pipeline.
classifier_outputs = classifier(input_dataset, candidate_labels, batch_size=4)

# Generate prediction results.
all_results = []
for result in tqdm(classifier_outputs, total=len(input_dataset)):
    # Display the first 5 results as the model is running.
    if len(all_results) < 5:
        text = result['sequence']
        labels_with_scores = [
            f'{label} ({score:.2%})'
            for label, score in zip(result['labels'], result['scores'])
        ]
        tqdm.write('')
        tqdm.write(f'Input text:         {text}')
        tqdm.write(f'Model predictions:  {", ".join(labels_with_scores)}')

    # Compile a list of all prediction results.
    all_results.append(result)

### Cell 5: Save the model predictions

This code prepares an output CSV file containing model predictions which can be used for further analysis.

In [None]:
def save_model_predictions(file_name, dataframe, all_results, candidate_labels):
    # Arrange the results in tabular form with neat columns.
    rows = []
    for result in all_results:
        labels = result['labels']
        scores_as_percentages = [round(score * 100, 2) for score in result['scores']]
        row = {'text': result['sequence'], **dict(zip(labels, scores_as_percentages))}
        rows.append(row)
    results_df = pd.DataFrame(rows, columns=['text', *candidate_labels])
    results_df['predicted_label'] = results_df[candidate_labels].idxmax(axis=1)

    # Append the hand-annotated ground truth column (if it exists).
    if 'hand_annotated' in dataframe.columns:
        results_df['hand_annotated'] = dataframe['hand_annotated']

    # Save output to a CSV file.
    os.makedirs('Outputs', exist_ok=True)
    output_file_name = os.path.join('Outputs', file_name)
    results_df.to_csv(output_file_name, index=False)

    try:
        # If we are on Google Colab, download the results.
        from google.colab import files
        files.download(output_file_name)
    except ModuleNotFoundError:
        # If we are not on Google Colab, show the output file location.
        print('Output file saved:')
        print(os.path.abspath(output_file_name))


save_model_predictions('cannabis_legalisation_predictions.csv',
                       dataframe, all_results, candidate_labels)
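
One detail in the function above is easy to miss: the pipeline returns labels sorted by score, so their order changes from one result to the next, and each score has to be matched to its column by label name. The sketch below shows this for a single result without pandas. The helper `result_to_row` and the example result are hypothetical and are used only for illustration.

```python
def result_to_row(result, candidate_labels):
    """Build one output row with a fixed column order from one pipeline result."""
    # Pair each label with its score, regardless of the order they arrived in.
    score_by_label = dict(zip(result['labels'], result['scores']))
    row = {'text': result['sequence']}
    for label in candidate_labels:
        row[label] = round(score_by_label[label] * 100, 2)
    row['predicted_label'] = max(candidate_labels, key=score_by_label.get)
    return row


# A made-up result in the shape the pipeline returns (highest score first).
example_result = {
    'sequence': 'A made-up post used only for illustration.',
    'labels': ['anti-legalisation', 'pro-legalisation'],
    'scores': [0.8, 0.2],
}

row = result_to_row(example_result, ['pro-legalisation', 'anti-legalisation'])
print(row)
```

The columns follow the order of `candidate_labels`, not the order of the scores, which keeps every row of the CSV file aligned.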

### Cell 6: Measure the accuracy of model predictions (optional)

Zero-shot learning does not require hand-annotated labels to generate predictions, but they can be used to validate the model's accuracy. Here we compare the model's outputs with hand-annotated (ground truth) labels. If you don't have hand-annotated labels for your data, skip this step.

In [None]:
correct_count = 0
total_count = 0

for result, ground_truth in zip(all_results, dataframe['hand_annotated']):
    total_count += 1
    correct_count += result['labels'][0] == ground_truth

accuracy = correct_count / total_count
print(f'Accuracy: {accuracy:.2%}')
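
A single accuracy figure can hide the fact that one label is predicted much better than the other. As a possible extension, the sketch below counts each (ground truth, prediction) pair and reports the share of examples of each true label that were predicted correctly. The function names and the four example labels are hypothetical.

```python
from collections import Counter


def confusion_counts(predicted_labels, true_labels):
    """Count how often each (true label, predicted label) pair occurs."""
    return Counter(zip(true_labels, predicted_labels))


def per_label_recall(counts):
    """For each true label, return the fraction of its examples predicted correctly."""
    totals = Counter()
    correct = Counter()
    for (truth, prediction), n in counts.items():
        totals[truth] += n
        if truth == prediction:
            correct[truth] += n
    return {label: correct[label] / totals[label] for label in totals}


# Made-up labels for illustration only.
predicted = ['pro-legalisation', 'pro-legalisation', 'anti-legalisation', 'pro-legalisation']
truth = ['pro-legalisation', 'anti-legalisation', 'anti-legalisation', 'pro-legalisation']

counts = confusion_counts(predicted, truth)
recall = per_label_recall(counts)
for label, value in recall.items():
    print(f'{label}: {value:.2%}')
```

With the notebook's own data, `predicted` could be built as `[result['labels'][0] for result in all_results]` and `truth` taken from `dataframe['hand_annotated']`.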

## Try your own analysis

*Ensure that you have run through the "Example scenario" first, as this code makes use of the classifier we initialised in that part of the tutorial.*

### Simple playground

In [None]:
# Try your own example by modifying the input text and candidate labels.

# Put the text that you want the model to classify here.
input_text = [
    'Minimum unit pricing is ridiculous and should be abolished. Big government should not tell me how much a drink should cost.',
]

# Put the options for the model to choose from here.
candidate_labels = [
    'supports minimum unit pricing',
    'does not support minimum unit pricing',
]

classifier(input_text, candidate_labels)
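
Because `input_text` is a list, the call above returns a list with one dictionary per input, each holding the `sequence`, the `labels` sorted from highest to lowest score, and the matching `scores`. The sketch below shows one way to read such a dictionary. The helper names, the example result, and the 0.6 review threshold are our own assumptions and should be adjusted to your data.

```python
def top_prediction(result):
    """Return (label, score) for the highest-scoring label of one result."""
    return result['labels'][0], result['scores'][0]


def needs_review(result, threshold=0.6):
    """Flag a result for manual checking when the top score is below the threshold."""
    _, score = top_prediction(result)
    return score < threshold


# A made-up result in the shape the pipeline returns (not real model output).
example_result = {
    'sequence': 'A made-up sentence used only for illustration.',
    'labels': ['does not support minimum unit pricing', 'supports minimum unit pricing'],
    'scores': [0.55, 0.45],
}

label, score = top_prediction(example_result)
print(f'Top label: {label} ({score:.2%})')
print(f'Needs manual review: {needs_review(example_result)}')
```

Flagging low-confidence predictions in this way is optional, but it can help to prioritise which examples to check by hand.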

### Bring your own CSV

Rather than entering your input text directly into the code as above, you can supply your own data as a CSV file. This makes it easier to experiment with much larger datasets. To do so, ensure that your file has a column called "text" with one example per row. You can create a file like this in Microsoft Excel by saving as a `.csv` file.

When running the following cell, you will be asked to select/input your data file using widgets that appear directly below the cell.

In [None]:
try:
    # If we are on Google Colab, show an upload widget.
    from google.colab import files
    uploaded = files.upload()
    if uploaded:
        data_location = list(uploaded.keys())[0]
    else:
        data_location = ''
except ModuleNotFoundError:
    # If we are not on Google Colab, ask for the name of the file.
    data_location = input('Please enter the name of the file '
                          '(e.g. Data/cannabis_legalisation.csv)\n> ')

# Check whether the file exists.
if not os.path.isfile(data_location):
    print(f'File not found: {data_location}')
    print('Please run this cell again.')
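
The requirement that the file has a "text" column with one example per row can be checked before any classification is run. The following is a minimal sketch that uses only the standard library `csv` module. The helper `check_text_column` is hypothetical, and the in-memory sample stands in for a real file.

```python
import csv
import io


def check_text_column(csv_file, column='text'):
    """Return (row_count, empty_rows) for an open CSV file.

    Raises ValueError if the required column is missing. `empty_rows` lists the
    1-based data row numbers whose value in that column is blank.
    """
    reader = csv.DictReader(csv_file)
    if column not in (reader.fieldnames or []):
        raise ValueError(f'No "{column}" column found. Columns: {reader.fieldnames}')
    row_count = 0
    empty_rows = []
    for row_number, row in enumerate(reader, start=1):
        row_count += 1
        if not (row[column] or '').strip():
            empty_rows.append(row_number)
    return row_count, empty_rows


# A made-up two-row sample; the second row has no text.
sample = io.StringIO(
    'text,hand_annotated\n'
    'First made-up post,pro-legalisation\n'
    ',anti-legalisation\n'
)
row_count, empty_rows = check_text_column(sample)
print(f'{row_count} rows, blank text in rows: {empty_rows}')
```

For a real file, the same function could be called as `check_text_column(open(data_location, newline='', encoding='utf-8'))`, assuming the file is UTF-8 encoded.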

In [None]:
# Read the input data.
dataframe = pd.read_csv(data_location)
input_dataset = KeyDataset(dataframe.to_dict('records'), key='text')

# Show the first 5 text examples.
for n, text in zip(range(1, 6), input_dataset):
    print(f'{n}. {text}')
print('   ...')

In [None]:
# Here we specify the label options that the model will choose from.
# Make sure that you update these options to suit your data and experiment
# with different wordings.
candidate_labels = ['pro-legalisation', 'anti-legalisation']

# Start the classification pipeline.
classifier_outputs = classifier(input_dataset, candidate_labels, batch_size=4)

# Generate prediction results.
all_results = []
for result in tqdm(classifier_outputs, total=len(input_dataset)):
    all_results.append(result)

In [None]:
# Save output to a CSV file.
save_model_predictions('predictions.csv', dataframe, all_results, candidate_labels)