<a href="https://colab.research.google.com/github/hjesse92/style_transfer_w266/blob/main/Few_Shot_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Few Show Learning

## Setup

In [1]:
!pip install -q transformers datasets sentencepiece rouge_score accelerate evaluate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m98.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m469.0/469.0 KB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m67.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.8/212.8 KB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 KB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.2/212.2 KB[0m [31m21.9 MB/s[0m e

In [2]:
#Am I running a GPU and what type is it?
!nvidia-smi

Fri Mar 10 19:34:06 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    25W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import torch

if torch.cuda.is_available():     
    device = torch.device("cuda")
    print('Number of GPU(s) available:', torch.cuda.device_count())
    print('GPU device name:', torch.cuda.get_device_name(0))

else:
    print('No GPU available')
    device = torch.device("cpu")

Number of GPU(s) available: 1
GPU device name: Tesla T4


In [4]:
from logging import warning
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, TensorDataset

# from transformers import BertTokenizer, BertModel

from sklearn.utils import resample
from sklearn.model_selection import train_test_split

import re
import random
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import pprint

import warnings
warnings.filterwarnings('ignore')

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7fec50803cb0>

In [7]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
!cd drive/MyDrive/data

In [6]:
train_file = '/content/drive/MyDrive/data/original-train.tsv'
dev_file = '/content/drive/MyDrive/data/original-dev.tsv'
test_file = '/content/drive/MyDrive/data/original-test.tsv'
df_train = pd.read_csv(train_file, sep='\t')
df_dev = pd.read_csv(dev_file, sep='\t')
df_test = pd.read_csv(test_file, sep='\t')

In [7]:
print(f'''mean length of offensive text: {df_train['offensive-text'].map(len).mean()}''')
print(f'''min length of offensive text: {df_train['offensive-text'].map(len).min()}''')
print(f'''max length of offensive text: {df_train['offensive-text'].map(len).max()}''')
print(f'''mean length of neutralized text: {df_train['style-transferred-text'].map(len).mean()}''')
print(f'''min length of neutralized text: {df_train['style-transferred-text'].map(len).min()}''')
print(f'''max length of neutralized text: {df_train['style-transferred-text'].map(len).max()}''')

mean length of offensive text: 69.85353535353535
min length of offensive text: 9
max length of offensive text: 238
mean length of neutralized text: 60.48800505050505
min length of neutralized text: 1
max length of neutralized text: 174


# Trial with Flan T5

In [10]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

t5tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto", torch_dtype=torch.float16)

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [49]:
#@title 3-Shot Learning
df_fewshot = df_test.copy()
n_shots = 3

for row in range(len(df_fewshot)):
    #For each item in the test set, we use 3 examples from the training set to as n-shots
    sources_targets = df_train.sample(n_shots, replace=False)

    for i in range(n_shots):
        df_fewshot.loc[row, f'shot{i+1}_source'] = sources_targets.iloc[i,0]
        df_fewshot.loc[row, f'shot{i+1}_target'] = sources_targets.iloc[i,1]


df_fewshot = df_fewshot.rename(columns={'offensive-text':'source', 'style-transferred-text':'target'})

In [52]:
df_fewshot['prompt'] = df_fewshot.apply(lambda x: 
                 'Rewrite the toxic text in non-toxic style. \n\n'
                 'Toxic text: ' + x['shot1_source'] + '\n' + 'Non-toxic text: ' + x['shot1_target'] + '\n\n' + \
                 'Toxic text: ' + x['shot2_source'] + '\n' + 'Non-toxic text: ' + x['shot2_target'] + '\n\n' + \
                 'Toxic text: ' + x['shot3_source'] + '\n' + 'Non-toxic text: ' + x['shot3_target'] + '\n\n' + \
                 'Toxic text: ' + x['source'] + '\n' + 'Non-toxic text: ', axis=1)

In [53]:
print(df_fewshot.prompt[0])

Rewrite the toxic text in non-toxic style. 

Toxic text: Another totally inaccurate post. But gullible people comment on this bs
Non-toxic text: Another totally inaccurate post. But gullible people comment on this kind of nonsense

Toxic text: You're an illiterate fool.
Non-toxic text: You're not being smart

Toxic text: Only an idiot thinks giving up your right to own a firearm is a good idea. Liberals=fucking pussies. Go suck betos dick ya fuck.
Non-toxic text: No smart person things giving up your right to own a firearm is a good idea. Liberals are cowards

Toxic text: So maybe you should be more retarded.
Non-toxic text: 


In [54]:
for i in range(len(df_fewshot)):
    if i%10 == 0:
        print(f'Working on number: {i}')
    t5_inputs = t5tokenizer([df_fewshot.prompt[i]], return_tensors='pt').input_ids.to("cuda")
    t5_summary_ids = t5model.generate(t5_inputs, 
                                    num_beams=3,
                                    no_repeat_ngram_size=1,
                                    num_return_sequences=1,
                                    top_p = 0.9,
                                    top_k = 50,
                                    min_length = 1,
                                    max_length = 200,
                                    do_sample=True,
                                    )
    output = t5tokenizer.batch_decode(t5_summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    df_fewshot.loc[i,'fewshot_output'] = output

Working on number: 0
Working on number: 10
Working on number: 20
Working on number: 30
Working on number: 40
Working on number: 50
Working on number: 60
Working on number: 70
Working on number: 80
Working on number: 90
Working on number: 100
Working on number: 110
Working on number: 120
Working on number: 130
Working on number: 140
Working on number: 150
Working on number: 160
Working on number: 170
Working on number: 180
Working on number: 190


In [55]:
df_fewshot.loc[:,['source','target','prompt','fewshot_output']]

Unnamed: 0,source,target,prompt,fewshot_output
0,So maybe you should be more retarded.,So maybe you should be more backward,Rewrite the toxic text in non-toxic style. \n\...,So maybe you should be more retarded.
1,THERES A MEGATHREAD FOR VACCINE OR COVID RELAT...,THERES ACTUALLY A MEGATHREAD FOR VACCINE OR CO...,Rewrite the toxic text in non-toxic style. \n\...,THERE IS A MEGATHREAD FOR VACCINE OR COVIDE RE...
2,the fuck.... you are on drugs,uh..... you are not being realistic,Rewrite the toxic text in non-toxic style. \n\...,the fuck.... you are on drugs
3,NTA. Dump his ass. Tablemanners are no rocke...,You should leave him. Tablemanners are no rock...,Rewrite the toxic text in non-toxic style. \n\...,Tablemanners are no rocket science. Treating o...
4,Youre soft as baby shit,Youre really soft,Rewrite the toxic text in non-toxic style. \n\...,You're soft as baby shit
...,...,...,...,...
194,"NTA, Jes-us your brother is dumb. How long did...","NTA, your brother is not thinking straight. Ho...",Rewrite the toxic text in non-toxic style. \n\...,Jes-us your brother is dumb. How long did he t...
195,Formula one…nascar….shit it exactly the same e...,Formula one…nascar….it exactly the same except...,Rewrite the toxic text in non-toxic style. \n\...,Formula one...nascar....shit it exactly the sa...
196,"BRB, gonna go call the mods pathetic egotistic...","BRB, gonna go call the mods.",Rewrite the toxic text in non-toxic style. \n\...,"BRB, gonna go call the mods stupid neck beard."
197,CUCKOLD Carlson is a problem,Carlson is the problem,Rewrite the toxic text in non-toxic style. \n\...,Carlson is a problem


In [56]:
df_fewshot.to_csv('test_output.csv',sep='\t',index=False)

## Evaluation

In [33]:
import evaluate

rouge = evaluate.load('rouge')
bleu = evaluate.load('bleu')
# bleurt = evaluate.load('bleurt')

In [47]:
#@title Baseline Score on the source and target
print(rouge.compute(predictions=df_fewshot.source,
              references=df_fewshot.target))
print(bleu.compute(predictions=df_fewshot.source,
              references=df_fewshot.target))

# If my predictions did nothing but repeat the same toxic text, I'd get these scores

{'rouge1': 0.6881033658473039, 'rouge2': 0.5580564200055431, 'rougeL': 0.6836240456794466, 'rougeLsum': 0.6836978988160836}
{'bleu': 0.5391232310503405, 'precisions': [0.6839945280437757, 0.5702752293577982, 0.49604117181314333, 0.436613665663945], 'brevity_penalty': 1.0, 'length_ratio': 1.1498230436492332, 'translation_length': 2924, 'reference_length': 2543}


In [48]:
#@title Score after few shot learning
print(rouge.compute(predictions=df_fewshot.fewshot_output,
              references=df_fewshot.target))
print(bleu.compute(predictions=df_fewshot.fewshot_output,
              references=df_fewshot.target))

{'rouge1': 0.6213289037367359, 'rouge2': 0.4787594491047313, 'rougeL': 0.6158645201277622, 'rougeLsum': 0.6151169976658839}
{'bleu': 0.44473163805572996, 'precisions': [0.6275704493526276, 0.4824886691388546, 0.39676840215439857, 0.32561576354679805], 'brevity_penalty': 1.0, 'length_ratio': 1.0326386158081007, 'translation_length': 2626, 'reference_length': 2543}


In [59]:
#@title Tune the model with training set, then do few-shot learning again

from transformers import TrainingArguments
from datasets import load_dataset

training_args = TrainingArguments(output_dir='tuned_t5_toxic')

# Metric
metric = evaluate.load("rouge")

# helper function to postprocess text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

In [60]:
from transformers import DataCollatorForSeq2Seq
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    t5tokenizer,
    model=t5model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)

In [12]:
df_train

Unnamed: 0,offensive-text,style-transferred-text
0,Pussy nobody asked for your input.,Nobody asked for your input
1,"Ok, this makes no sense. This will create vigi...","Ok, this makes no sense. This will create vigi..."
2,so fucking true. the amount of up and coming r...,so true. the amount of up and coming rappers t...
3,Go f yourself Republican scum who put us here ...,Republicans put us in this situation. I would ...
4,Dumb fucking take. People want to do things.,That's not a smart take. People want to do thi...
...,...,...
1579,The View! And this crap hits my front page!? l...,"This must be mostly bots but still, it's stran..."
1580,That’s racist.,Those actions are racist.
1581,Cultural Marxism isn't a thing you weird fuck.,Cultural Marxism isn't a thing.
1582,"LOL, anyone that questions the Democrat progra...","LOL, anyone that questions the Democrat progra..."


In [None]:
from huggingface_hub import HfFolder
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hugging Face repository id
# repository_id = f"{model_id.split('/')[1]}-{dataset_id}"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir='t5_tuned',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    fp16=False, # Overflows with fp16
    learning_rate=5e-5,
    num_train_epochs=5,
    # logging & evaluation strategies
    logging_dir=f"{repository_id}/logs",
    logging_strategy="steps",
    logging_steps=500,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    # metric_for_best_model="overall_f1",
    # push to hub parameters
    report_to="tensorboard",
    push_to_hub=False,
    hub_strategy="every_save",
    hub_model_id=repository_id,
    hub_token=HfFolder.get_token(),
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=t5model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
)

In [62]:
from datasets import load_dataset

dataset = load_dataset('yelp_review_full')

Downloading builder script:   0%|          | 0.00/4.41k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.04k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.55k [00:00<?, ?B/s]

Downloading and preparing dataset yelp_review_full/yelp_review_full to /root/.cache/huggingface/datasets/yelp_review_full/yelp_review_full/1.0.0/e8e18e19d7be9e75642fc66b198abadb116f73599ec89a69ba5dd8d1e57ba0bf...


Downloading data:   0%|          | 0.00/196M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/650000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Dataset yelp_review_full downloaded and prepared to /root/.cache/huggingface/datasets/yelp_review_full/yelp_review_full/1.0.0/e8e18e19d7be9e75642fc66b198abadb116f73599ec89a69ba5dd8d1e57ba0bf. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [22]:
x_train = df_train['offensive-text'].map(tokenize_function)

In [24]:
x_train[0]

{'input_ids': [5004, 7, 7, 63, 12638, 1380, 21, 39, 3785, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [67]:
tokenized_datasets = dataset.map(tokenize_function, batched=True)
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))



In [68]:
class ToxicData(Dataset):
    def __init__(self, data, tokenizer):
        super(Dataset, self).__init__()
        self.source_texts = data['offensive-text']
        self.target_texts = data['style-transferred-text']
        self.tokenizer=tokenizer

    def __len__(self):
        return len(self.target_texts)

    def __getitem__(self, idx):
        source_text = self.source_texts[idx]
        target_text = self.target_texts[idx]

        source_tokens = self.tokenizer.encode(source_text, padding='max_length',truncation=True)
        source_tensor = torch.tensor(source_tokens, dtype=torch.long)

        target_tokens = self.tokenizer.encode(target_text, padding='max_length',truncation=True)
        target_tensor = torch.tensor(target_tokens, dtype=torch.long)

        if torch.cuda.is_available():
            source_tensor = source_tensor.cuda()
            target_tensor = target_tensor.cuda()

        return source_tensor, target_tensor

In [62]:
def tokenize_function(examples):
    return t5tokenizer(examples, padding="max_length", truncation=True)['input_ids']

In [69]:
train_data = ToxicData(df_train, t5tokenizer)

In [70]:
train_loader = DataLoader(train_data, batch_size=8)

In [71]:
next(iter(train_loader))

[tensor([[ 5004,     7,     7,  ...,     0,     0,     0],
         [ 8872,     6,    48,  ...,     0,     0,     0],
         [   78,     3,    89,  ...,     0,     0,     0],
         ...,
         [  377,  4636,   150,  ...,     0,     0,     0],
         [  100,    19,   310,  ...,     0,     0,     0],
         [ 9758,   231, 26557,  ...,     0,     0,     0]], device='cuda:0'),
 tensor([[22009,  1380,    21,  ...,     0,     0,     0],
         [ 8872,     6,    48,  ...,     0,     0,     0],
         [   78,  1176,     5,  ...,     0,     0,     0],
         ...,
         [  150,     1,     0,  ...,     0,     0,     0],
         [  100,    19,   310,  ...,     0,     0,     0],
         [ 9758,    25,   278,  ...,     0,     0,     0]], device='cuda:0')]