-
Notifications
You must be signed in to change notification settings - Fork 255
/
tasks.py
129 lines (100 loc) · 4.32 KB
/
tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import concurrent.futures
import json
import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Callable

import pandas as pd
import requests
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm import tqdm
class DataProcessor(ABC):
    """Abstract base for dataset-backed tasks.

    Holds the dataset directory and a worker-count setting, and declares the
    four hooks that every concrete task has to implement.
    """

    def __init__(self, data_dir, max_threads=1):
        """Store the task configuration.

        Args:
            data_dir: Directory the task's data files are read from.
            max_threads: Worker count kept on the instance (defaults to 1).
        """
        self.max_threads = max_threads
        self.data_dir = data_dir

    @abstractmethod
    def get_train_examples(self):
        """Return the training examples for this task."""

    @abstractmethod
    def get_test_examples(self):
        """Return the test examples for this task."""

    @abstractmethod
    def evaluate(self, predictor, test_exs):
        """Score ``predictor`` on ``test_exs``.

        NOTE(review): ClassificationTask.evaluate in this file additionally
        takes ``prompt`` and ``n``; the two signatures do not line up.
        """

    @abstractmethod
    def stringify_prediction(self, pred):
        """Return a string form of the prediction ``pred``."""
def process_example(ex, predictor, prompt):
    """Run ``predictor.inference`` on one example.

    Args:
        ex: The example to predict on; handed back unchanged.
        predictor: Object exposing ``inference(ex, prompt)``.
        prompt: Prompt forwarded to the predictor.

    Returns:
        A ``(ex, prediction)`` tuple, so the caller can pair each prediction
        with the example that produced it.
    """
    return ex, predictor.inference(ex, prompt)
class ClassificationTask(DataProcessor):
    """Task whose predictions are scored against gold labels with micro-F1."""

    def run_evaluate(self, predictor, prompt, test_exs, n=100):
        """Predict on the first ``n`` test examples and score the result.

        Every example is submitted to a process pool as
        ``process_example(ex, predictor, prompt)``. Results are collected in
        completion order, so the returned lists are aligned with one another
        but not necessarily with the order of ``test_exs``.

        Args:
            predictor: Object passed through to ``process_example``.
            prompt: Prompt passed through to ``process_example``.
            test_exs: Sequence of examples; each must have ``'text'`` and
                ``'label'`` keys.
            n: Maximum number of examples (taken from the front) to evaluate.

        Returns:
            ``(f1, texts, labels, preds)`` where ``f1`` is the micro-averaged
            F1 of ``preds`` against ``labels``.
        """
        texts, labels, preds = [], [], []
        with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_threads) as executor:
            futures = [
                executor.submit(process_example, ex, predictor, prompt)
                for ex in test_exs[:n]
            ]
            completed = concurrent.futures.as_completed(futures)
            for future in tqdm(completed, total=len(futures), desc='running evaluate'):
                ex, pred = future.result()
                texts.append(ex['text'])
                labels.append(ex['label'])
                preds.append(pred)

        f1 = f1_score(labels, preds, average='micro')
        return f1, texts, labels, preds

    def evaluate(self, predictor, prompt, test_exs, n=100):
        """Call ``run_evaluate``, retrying on worker-pool or SSL failures.

        A broken process pool or an SSL error causes the whole evaluation to
        be re-run from scratch. There is no retry limit; each failure is
        logged so that a persistent fault is visible instead of looping
        silently. Any other exception propagates.

        Args:
            predictor: Forwarded to ``run_evaluate``.
            prompt: Forwarded to ``run_evaluate``.
            test_exs: Forwarded to ``run_evaluate``.
            n: Forwarded to ``run_evaluate``.

        Returns:
            The ``(f1, texts, labels, preds)`` tuple from the first
            ``run_evaluate`` call that completes.
        """
        attempt = 0
        while True:
            attempt += 1
            try:
                return self.run_evaluate(predictor, prompt, test_exs, n=n)
            # The exception tuple is deliberately built here, not hoisted:
            # ``concurrent.futures.process`` is only guaranteed to be loaded
            # once the pool has been created inside ``run_evaluate``.
            except (concurrent.futures.process.BrokenProcessPool, requests.exceptions.SSLError) as err:
                logging.getLogger(__name__).warning(
                    'run_evaluate attempt %d failed with %r; retrying', attempt, err
                )
class BinaryClassificationTask(ClassificationTask):
    """Classification task with two categories, addressed by index 0 and 1."""

    # Category names; the list position is the integer prediction/label.
    categories = ['No', 'Yes']

    def stringify_prediction(self, pred):
        """Return the category name at index ``pred``.

        The lookup goes through ``self`` so that a subclass which overrides
        ``categories`` gets its own names rather than this class's.

        Args:
            pred: Index into ``categories``.

        Returns:
            The category name stored at that index.
        """
        return self.categories[pred]
class EthosBinaryTask(BinaryClassificationTask):
    """Binary task read from ``ethos_ishate_binary_shuf.csv`` in ``data_dir``.

    The first ``_N_TEST`` rows that survive filtering form the test split;
    all remaining rows form the training split.
    """

    categories = ['No', 'Yes']

    # Number of leading (post-filter) rows reserved for the test split.
    _N_TEST = 200

    def _load_examples(self):
        """Read the CSV, filter it, and build one example dict per kept row.

        The file is ``;``-separated with no header row, so columns are
        addressed by position: column 0 becomes ``'text'`` and column 1 is
        the numeric value used for filtering and labelling.

        Returns:
            A list of ``{'id', 'text', 'label'}`` dicts in file order.
            ``'id'`` is the row's index from before filtering, and
            ``'label'`` is 1 when column 1 exceeds 0.4, otherwise 0.
        """
        df = pd.read_csv(self.data_dir + '/ethos_ishate_binary_shuf.csv', sep=';', header=None)
        # Keep only rows whose column-1 value is <= 0 or >= 0.7.
        df = df[(df[1] <= 0) | (df[1] >= 0.7)]
        # reset_index() exposes the pre-filter index as an 'index' column.
        records = df.reset_index().to_dict('records')
        return [
            {'id': row['index'], 'text': row[0], 'label': 1 if row[1] > 0.4 else 0}
            for row in records
        ]

    def get_train_examples(self):
        """Return every filtered example after the first ``_N_TEST``."""
        return self._load_examples()[self._N_TEST:]

    def get_test_examples(self):
        """Return the first ``_N_TEST`` filtered examples."""
        return self._load_examples()[:self._N_TEST]
class JailbreakBinaryTask(BinaryClassificationTask):
    """Binary task read from ``train.tsv`` / ``test.tsv`` in ``data_dir``."""

    categories = ['No', 'Yes']

    def _read_tsv(self, filename):
        """Parse one TSV split into example dicts.

        Each line holds two tab-separated fields: a JSON-encoded list of
        turns (dicts with ``'role'`` and ``'text'``) and an integer label.
        The example text is the stripped text of all ``'user'`` turns joined
        with single spaces.

        Args:
            filename: File name inside ``data_dir``.

        Returns:
            A list of ``{'id', 'text', 'label'}`` dicts, where ``'id'`` is
            the zero-based line number.

        Raises:
            ValueError: If a line does not split into exactly two fields, or
                the label is not an integer.
        """
        exs = []
        # The context manager closes the file handle even if parsing fails.
        with open(self.data_dir + '/' + filename) as f:
            for i, line in enumerate(f):
                convo, label = line.strip().split('\t')
                label = int(label)
                text = ' '.join(
                    turn['text'].strip()
                    for turn in json.loads(convo)
                    if turn['role'] == 'user'
                )
                exs.append({'id': i, 'text': text, 'label': label})
        return exs

    def get_train_examples(self):
        """Return the examples parsed from ``train.tsv``."""
        return self._read_tsv('train.tsv')

    def get_test_examples(self):
        """Return the examples parsed from ``test.tsv``."""
        return self._read_tsv('test.tsv')
class DefaultHFBinaryTask(BinaryClassificationTask):
    """Binary task read from ``train.jsonl`` / ``test.jsonl`` in ``data_dir``."""

    categories = ['No', 'Yes']

    def _read_jsonl(self, split):
        """Parse ``<split>.jsonl`` into example dicts.

        Every line is a JSON object that must contain ``'label'`` and
        ``'text'`` keys.

        Args:
            split: Split name; used both as the file stem and as the prefix
                of each example id.

        Returns:
            A list of ``{'id', 'label', 'text'}`` dicts, where ``'id'`` is
            ``'<split>-<zero-based line number>'``.
        """
        exs = []
        # The context manager closes the file handle even if parsing fails.
        with open(self.data_dir + '/' + split + '.jsonl') as f:
            for i, line in enumerate(f):
                row = json.loads(line.strip())
                exs.append({'id': f'{split}-{i}', 'label': row['label'], 'text': row['text']})
        return exs

    def get_train_examples(self):
        """Return the examples parsed from ``train.jsonl``."""
        return self._read_jsonl('train')

    def get_test_examples(self):
        """Return the examples parsed from ``test.jsonl``."""
        return self._read_jsonl('test')