# translate_data.py
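"""Caption translation worker for laion2B-multi.

Fetches translation jobs from the captionsathome tracker, shards the assigned
parquet file across GPU worker processes, translates every caption to English
with facebook/m2m100_1.2B, and writes the merged result back out as parquet.
"""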
import warnings
from datasets.utils.logging import set_verbosity_error
import time
import os
import shutil
import pandas as pd
from datasets import load_dataset
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
import torch.multiprocessing as mp
import captionsathome as cah
import torch
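# Keep long-running jobs quiet: suppress datasets logging and Python warnings.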
set_verbosity_error()
warnings.filterwarnings("ignore")
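# DATA is scratch space for per-shard intermediates; DATA_ receives the
# finished translations.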
TMP_DATA_DIR = 'DATA'
FINAL_DATA_DIR = 'DATA_'

if not os.path.exists(TMP_DATA_DIR):
    os.makedirs(TMP_DATA_DIR)
if not os.path.exists(FINAL_DATA_DIR):
    os.makedirs(FINAL_DATA_DIR)
if torch.cuda.is_available():
    num_devices = torch.cuda.device_count()
    torch.cuda.empty_cache()
else:
    raise RuntimeError("No GPU available")
def translate(text, tokenizer, model, device):
    '''
    text:List[str] batch of captions to translate, len==batch_size
    tokenizer:transformers.PreTrainedTokenizer with src_lang already set
    model:transformers.PreTrainedModel
    device:torch.device the model lives on
    returns:
        List[str] of English translations
    '''
    torch.cuda.empty_cache()
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt",
                            padding=True).to(device)
        generated_tokens = model.generate(
            **encoded, forced_bos_token_id=tokenizer.get_lang_id("en")
        )
        eng_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True)
    torch.cuda.empty_cache()
    return eng_text
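# Example (hypothetical strings; the caller must set tokenizer.src_lang first):
#   tokenizer.src_lang = "de"
#   translate(["Guten Morgen"], tokenizer, model, torch.device("cuda:0"))
#   -> ["Good morning"]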
def process_data(d, name, folder, tokenizer, model, device):
    '''
    d:dataset that we want to process
    name:str name of the parquet file (in our case the language)
    folder:str folder to save the processed data (a different folder for each process)
    '''
    d = d.map(lambda example: {'ENG TEXT': translate(example['TEXT'], tokenizer, model, device)},
              batched=True, batch_size=10)
    d.to_parquet(f'{folder}/{name}.parquet', batch_size=100)
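# translate() sees batches of 10 rows at a time; parquet rows are written in
# batches of 100. Both sizes can be tuned for GPU memory and throughput.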
def translate_part(tokenizer, model, start, end, small_dataset, device, DATA_DIR):
    '''
    Translates part of the desired dataset.
    tokenizer:transformers.PreTrainedTokenizer
    model:transformers.PreTrainedModel
    start:int start of the range the dataset is split on
    end:int end of the range the dataset is split on
    small_dataset: the part of the dataset that is being used
    '''
    directory = DATA_DIR + f'/{start}-{end}'
    if not os.path.exists(directory):
        os.makedirs(directory)
    part_dataset = small_dataset.select(range(start, end))
    # get all languages that are present in this part of the dataset
    langs = part_dataset.unique('LANGUAGE')
    # go through all the languages
    for lang in langs:
        try:
            tokenizer.src_lang = lang
            process_data(part_dataset.filter(
                lambda example: example["LANGUAGE"] == lang),
                lang,
                directory,
                tokenizer,
                model,
                device
            )
        # the model doesn't support some languages
        except KeyError:
            continue
        torch.cuda.empty_cache()
    # combine all parquet files into one
    fs = [pd.read_parquet(directory + '/' + path)
          for path in os.listdir(directory)]
    pd.concat(fs).to_parquet(f'{DATA_DIR}/{start}-{end}.parquet')
    # remove the directory that is not needed anymore
    shutil.rmtree(directory)
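# Each worker process produces one {DATA_DIR}/{start}-{end}.parquet file;
# worker() below merges the per-process files into the final output.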
def worker():
    client = cah.init(
        url="http://cah.io.community/",
        device_id="cluster"
    )
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
    # target language
    tokenizer.tgt_lang = "en"
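    # Job URLs look like "<...>.parquet:<partition>"; the suffix selects which
    # of the 360 shards of that parquet file this job should translate.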
    while client.jobCount() > 0 and not client.shouldDie():
        client.newJob()
        client.log("Processing...")
        client.log(client.tar_url)
        url, partition = client.tar_url.split("parquet:")
        url += 'parquet'
        url = url.split('/')[-1]
        print(url)
        DATA_DIR = TMP_DATA_DIR + f'/{url}'
        dataset = load_dataset("laion/laion2B-multi-joined", data_files=url, split="train")
        shard = dataset.shard(
            360, int(partition), contiguous=True, keep_in_memory=True)
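        # `shard` is this job's contiguous 1/360 slice of the parquet file.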
        s = time.time()
        # to make parallel GPU computations possible
        try:
            mp.set_start_method('spawn', force=True)
        except RuntimeError:
            pass
        # number of worker processes per GPU; can be varied
        num_proc_per_device = 3
        num_processes = num_devices * num_proc_per_device
        device_ids = list(range(num_devices)) * num_proc_per_device
        processes = []
        n_samples = len(shard)
        c = n_samples // num_processes
        ranges = [[i, i + c] for i in range(0, c * num_processes, c)]
        # the last process also takes the remainder left by integer division
        ranges[-1][1] = n_samples
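        # One model replica per process, num_proc_per_device processes per GPU;
        # each replica is pinned to its device and shared via share_memory().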
        for rank, rng, device_id in zip(range(num_processes), ranges, device_ids):
            device = torch.device(f"cuda:{device_id}")
            print(f'Using: {device}')
            model = M2M100ForConditionalGeneration.from_pretrained(
                "facebook/m2m100_1.2B").to(device)
            model = model.half()
            model = model.eval()
            model.share_memory()
            start, end = rng
            p = mp.Process(
                target=translate_part,
                args=[tokenizer, model, start, end, shard, device, DATA_DIR]
            )
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
        # combine the parquet files from every process into one file
        fs = [pd.read_parquet(DATA_DIR + '/' + path)
              for path in os.listdir(DATA_DIR)]
        DATA_DIR_TRANSLATED = f'{FINAL_DATA_DIR}/{url}'
        if not os.path.exists(DATA_DIR_TRANSLATED):
            os.makedirs(DATA_DIR_TRANSLATED)
        pd.concat(fs).to_parquet(
            f'{DATA_DIR_TRANSLATED}/{url}_{partition}.parquet')
        # remove the directories we don't need anymore
        shutil.rmtree(DATA_DIR)
        e = time.time()
        print(f'Processed in {round(e - s, 2)} seconds')
        client.completeJob()
    client.bye()
    return 0
if __name__ == '__main__':
    exit(worker())