In [1]:
import torch
from transformers import BartForSequenceClassification, BartTokenizer

# Prefer the first CUDA GPU when torch can see one; otherwise run on the CPU.
# (To force CPU inference, assign DEVICE = "cpu" after this check.)
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda:0"

print(DEVICE)

class BartZeroShot:
    """Zero-shot classifier built on an NLI model (default: facebook/bart-large-mnli).

    A candidate label is turned into an NLI hypothesis (by default
    "This example is <label>"). The model then scores whether the input sentence,
    used as the premise, entails that hypothesis.
    """

    def __init__(self, model_name="facebook/bart-large-mnli", device=None):
        """Load the NLI model and its tokenizer.

        Args:
            model_name: Hugging Face model id or local path. Assumes a BART-family
                NLI checkpoint whose logits are ordered (contradiction, neutral,
                entailment), as for facebook/bart-large-mnli.
            device: torch device to run on; defaults to the notebook-level DEVICE.
        """
        self.device = DEVICE if device is None else device
        # from_pretrained already returns the model in eval mode; calling .eval()
        # explicitly keeps dropout off regardless of how the model is obtained.
        self.nli_model = (
            BartForSequenceClassification.from_pretrained(model_name)
            .to(self.device)
            .eval()
        )
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

    @torch.no_grad()  # inference only: skip building the autograd graph
    def predict(self, sentence, label, hypothesis_template="This example is {}"):
        """Return the probability that ``label`` applies to ``sentence``.

        Args:
            sentence: input text, used as the NLI premise.
            label: candidate label substituted into ``hypothesis_template``.
            hypothesis_template: format string with a single ``{}`` placeholder.

        Returns:
            float in [0, 1]. This is the entailment share of a softmax over the
            (contradiction, entailment) logits; the neutral logit is ignored.
        """
        hypothesis = hypothesis_template.format(label)
        # Truncate only the premise so the short hypothesis is never cut off.
        input_ids = self.tokenizer.encode(
            sentence,
            hypothesis,
            return_tensors="pt",
            truncation="only_first",
        ).to(self.device)
        logits = self.nli_model(input_ids)[0]

        # Drop the neutral class (index 1) and renormalise over
        # (contradiction, entailment); column 1 of the result is entailment.
        contradiction_entailment = logits[:, [0, 2]]
        probs = contradiction_entailment.softmax(dim=1)
        return probs[:, 1].item()

cuda:0


In [2]:
# Load the NLI model and tokenizer once; the cells below reuse this instance.
bz = BartZeroShot()



In [3]:
# Sanity check: a clearly negative sentence should get a low "positive" probability.
bz.predict("I really really hate my life", "positive")

0.0009476275881752372

In [4]:
# Sanity check: a clearly positive sentence should get a high "positive" probability.
bz.predict("I really really love my life", "positive")

0.9730832576751709

In [18]:
# Set up the AWS / SageMaker sessions; the dataset itself is loaded from S3 further below.
import boto3
import sagemaker

region_name = 'eu-central-1'

session = boto3.Session(region_name=region_name)
# NOTE(review): despite its name this is an S3 *client*, not a session, and no visible
# cell uses it — confirm before removing.
s3_sess = session.client('s3')
# SageMaker session bound to the same boto3 session; used below for default_bucket().
sm_session = sagemaker.Session(boto_session=session)

In [30]:
# NOTE(review): this deletes the whole local Hugging Face datasets cache directory
# (every cached dataset, not only twitter_ds). The load_dataset call in the next cell
# already passes download_mode="force_redownload", so this wipe may be redundant.
!rm -rf ~/.cache/huggingface/datasets


In [31]:
from datasets import load_dataset

# S3 prefix of the twitter_ds dataset inside the SageMaker default bucket.
input_path = f's3://{sm_session.default_bucket()}/datasets/twitter_ds/'

# NOTE(review): input_path already ends with '/', so this builds
# ".../twitter_ds//validation/dataset.json" (double slash). Whether the filesystem layer
# collapses "//" is not visible here. The path did load in the recorded run, so check
# the actual object key in the bucket before "fixing" either string.
validation_dataset_s3_path = f"{input_path}/validation/dataset.json"

# split='validation' returns a single Dataset rather than a DatasetDict;
# force_redownload bypasses any previously cached copy.
dataset = load_dataset(
    'json',
    data_files={'validation': validation_dataset_s3_path},
    split='validation',
    download_mode="force_redownload"
)

dataset

Downloading data:   0%|          | 0.00/964k [00:00<?, ?B/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['text', 'label'],
    num_rows: 10000
})

In [32]:
# Peek at the first record to confirm the 'text' / 'label' fields.
dataset[0]

{'text': '@nicepaul &quot;guerilla&quot; usability tests... nice one ',
 'label': 1}

In [33]:
# Number of leading rows to evaluate; None evaluates the full split.
# Set e.g. N_SAMPLES = 60 for a quick smoke test (each row is one model call).
N_SAMPLES = None

if N_SAMPLES is None:
    subset = dataset
else:
    subset = dataset.select(range(min(N_SAMPLES, len(dataset))))

# Display the subset's features and row count.
subset

Dataset({
    features: ['text', 'label'],
    num_rows: 10000
})

In [34]:
def is_positive(text, threshold=0.5, model=None):
    """Classify ``text`` as positive (1) or not positive (0) using zero-shot NLI.

    Args:
        text: input sentence.
        threshold: probability above which the text counts as positive. The
            comparison is strict, so a score exactly equal to ``threshold`` maps to 0.
        model: object exposing ``predict(text, label) -> float``. Defaults to the
            notebook-level ``bz`` instance.

    Returns:
        int: 1 if the "positive" probability exceeds ``threshold``, else 0.
    """
    if model is None:
        model = bz
    prob = model.predict(text, "positive")
    return int(prob > threshold)

In [35]:
# The clearly negative sanity-check sentence should be classified as not positive (0).
is_positive("I really really hate my life")

0

In [36]:
from sklearn.metrics import accuracy_score

# Process the dataset
def process_dataset(dataset, text_key="text", label_key="label", predict_fn=None):
    """Run a classifier over every example and collect predictions and gold labels.

    Args:
        dataset: iterable of mapping-like examples (e.g. a ``datasets.Dataset``).
        text_key: name of the field holding the input text.
        label_key: name of the field holding the gold label.
        predict_fn: callable mapping a text to a predicted label. Defaults to
            ``is_positive``.

    Returns:
        tuple ``(predictions, true_labels)`` of two lists, aligned with ``dataset``.
    """
    if predict_fn is None:
        predict_fn = is_positive

    predictions = []
    true_labels = []
    for example in dataset:
        true_labels.append(example[label_key])
        predictions.append(predict_fn(example[text_key]))

    return predictions, true_labels

# Compute accuracy
def compute_metrics(predictions, true_labels):
    """Score predicted labels against gold labels.

    Args:
        predictions: predicted labels, position-aligned with ``true_labels``.
        true_labels: gold labels.

    Returns:
        dict with a single ``"accuracy"`` entry: the fraction of matching positions.
    """
    return {"accuracy": accuracy_score(y_true=true_labels, y_pred=predictions)}

In [37]:
# Run zero-shot inference over every example in `subset`. This makes one model call per
# row, so it is slow when `subset` is the full split.
predictions, true_labels = process_dataset(subset)

# Compute and print metrics
metrics = compute_metrics(predictions, true_labels)
print(metrics)

{'accuracy': 0.7075}
