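### Title: Continuous data for continuous fine-tuning

This notebook streams a small batch of new records from the latest FineWeb dump, flags each record (Llama Guard text safety, domain safety and search-engine indexing, LLM-assigned topic flags), filters out flagged rows, and optionally asks an LLM to shorten the remaining text for a fine-tuning dataset.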

Copyright (c) 2024 Praneeth Vadlapati

In [25]:
import io
import os
import time
from urllib.parse import urlparse

from datasets import load_dataset, get_dataset_config_names
from dotenv import load_dotenv
from groq import Groq
import pandas as pd
import replicate
# import requests
from tavily import TavilyClient


load_dotenv()

dataset_name = 'HuggingFaceFW/fineweb'

data_name = dataset_name.split('/')[-1]  # get part after last /
folder = 'data'
os.makedirs(folder, exist_ok=True)  # the dump-name file below is written into this folder

# The newest dump name is cached in a small text file so later runs keep
# appending batches to the same dump folder. A missing or empty file means
# "look it up on the Hub".
latest_dump_file = os.path.join(folder, f'{data_name}-latest_dump.txt')
latest_dump_name = ''
if os.path.exists(latest_dump_file):
	with open(latest_dump_file, 'r') as f:
		latest_dump_name = f.read().strip()

if not latest_dump_name:
	print('Fetching latest dump name...')
	versions = get_dataset_config_names(dataset_name)
	# skip the aggregate 'default' config and the 'sample*' subsets
	versions = [v for v in versions if v != 'default' and not v.startswith('sample')]
	if not versions:
		raise ValueError(f'No dump configs found for {dataset_name}')
	# assumes lexicographic order of names like 'CC-MAIN-2024-18' is chronological
	latest_dump_name = max(versions)

# save latest dump name to file
with open(latest_dump_file, 'w') as f:
	f.write(latest_dump_name)

# create folder for saving csv files
data_dir = os.path.join(folder, f'{data_name}-{latest_dump_name}')
os.makedirs(data_dir, exist_ok=True)

ext = 'csv'

# True: re-process (flag/filter) the most recent existing batch file
# instead of streaming a new batch from the dataset.
PROCESS_EXISTING_FILE = True

### Collect the data

In [26]:
def get_filename(index, process_type='full'):
	'''Return the CSV path for batch `index` at a given processing stage.

	The name is `New_data - {index}.{process_type}.{ext}` inside `data_dir`
	(e.g. 'New_data - 0.flagged.csv'). A falsy `process_type` ('' or None)
	gives `New_data - {index}.{ext}` with no stage suffix.
	'''
	suffix = f'.{process_type}' if process_type else ''
	return os.path.join(data_dir, f'New_data - {index}{suffix}.{ext}')


# Find the first batch index whose full CSV is missing or empty. Rows saved in
# earlier batches are skipped when streaming, so each new batch is new data.
skip_index = 0
last_existing_file_index = -1
index = 0
while True:
	candidate_filename = get_filename(index)
	if not (os.path.exists(candidate_filename) and os.path.getsize(candidate_filename) > 0):
		break  # found file that doesn't exist or is empty
	last_existing_file_index = index
	# NOTE(review): this counts rows *saved* (after the language filter below),
	# not rows *streamed*; if rows were dropped, the next batch starts a bit early.
	skip_index += len(pd.read_csv(candidate_filename))
	index += 1

reuse_existing = PROCESS_EXISTING_FILE and last_existing_file_index != -1
if reuse_existing:
	index = last_existing_file_index

new_data_filename = get_filename(index)
flagged_data_filename = get_filename(index, 'flagged')
filtered_data_filename = get_filename(index, 'filtered')
short_text_filename = get_filename(index, 'shortened')

if reuse_existing:
	full_df = pd.read_csv(new_data_filename)
else:
	# stream a new batch from the dataset, skipping rows already saved
	start_time = time.time()
	dataset = load_dataset(dataset_name, name=latest_dump_name, split='train', streaming=True)
	if skip_index:
		dataset = dataset.skip(skip_index)

	limit = 100  # rows per batch
	data = list(dataset.take(limit))

	time_min, time_sec = divmod(time.time() - start_time, 60)
	print(f'Time taken: {time_min:.0f} min {time_sec:.0f} sec')

	full_df = pd.DataFrame(data)
	if 'language' in full_df.columns:  # absent when the stream yielded nothing
		full_df = full_df[full_df['language'] == 'en']  # drop rows with language not 'en'
	columns_to_remove = ['dump', 'date', 'file_path', 'language',
						'token_count', 'language_score', 'filter_reason']
	full_df = full_df.drop(columns=columns_to_remove, errors='ignore')

	if not full_df.empty:
		full_df.to_csv(new_data_filename, index=False)


print('New data file:', os.path.basename(new_data_filename))
print('Flagged data file:', os.path.basename(flagged_data_filename))
full_df.head(2)

New data file: New_data - 0.full.csv
Flagged data file: New_data - 0.flagged.csv


Unnamed: 0,text,id,url
0,We want to know how to best serve you. Please ...,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback
1,Architectural Control Committee Policies and F...,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...


### Flag the data

#### Essential functions

In [27]:
safe_flag = 'safe'

def is_unsafe(text, prompt='', max_retries=3):
	'''Ask Llama Guard 2 (via Replicate) whether `text` is unsafe.

	`text` is sent as the assistant turn and `prompt` as the user turn.

	Returns:
		False -- the model answered 'safe'.
		str   -- the category codes after 'unsafe', e.g. 'S5'.
		True  -- `text` is empty/missing, or the model said 'unsafe' with no category.
		None  -- every attempt failed (the verdict is unknown).
	'''
	# Checked once, outside the retry loop: empty or missing input is reported as unsafe.
	if not isinstance(text, str) or not text:
		return True

	for _ in range(max_retries):
		try:
			output = replicate.run(
				'meta/meta-llama-guard-2-8b:b063023ee937f28e922982abdbf97b041ffe34ad3b35a53d33e1d74bb19b36c4',
				input={ 'prompt': prompt, 'assistant': text },
			)
			if not isinstance(output, str):
				raise Exception('Invalid output')
			output = output.strip()  # tolerate surrounding whitespace/newlines
			if output == safe_flag:
				return False
			reason = output.replace('unsafe', '').strip()
			# a bare 'unsafe' must stay truthy, otherwise callers would read it as safe
			return reason or True
		except Exception as e:
			print(f'Error: {e}. Retrying')
			time.sleep(1)
	return None


# Resume from the saved flagged file when it exists and has rows; otherwise
# start from the freshly collected batch, dropping rows without any text.
flagged_df = pd.read_csv(flagged_data_filename) if os.path.exists(flagged_data_filename) else None

if flagged_df is None or flagged_df.empty:
	flagged_df = full_df.copy().dropna(subset=['text'])


def save_flagged_df(df=None, filename=None):
	'''Write the flagged DataFrame to CSV (without the index).

	Args:
		df: DataFrame to save. Defaults to the global `flagged_df`, which is
			looked up at call time rather than frozen when this function is defined.
		filename: target path; defaults to `flagged_data_filename`.
	'''
	if df is None:
		df = flagged_df
	if filename is None:
		filename = flagged_data_filename
	df.to_csv(filename, index=False)


# Tavily search client used by is_indexed(); the API key is read from the environment.
tavily = TavilyClient(api_key=os.environ.get('TAVILY_API_KEY'))

def is_indexed(domain, max_retries=3):
	'''Search online (Tavily, query `site:{domain}`) to check if a domain is indexed.

	Returns True when the search returns at least one result and False when it
	returns none. Returns None if every attempt failed; a caller testing
	`not is_indexed(...)` then treats the domain as unindexed.
	'''
	for attempt in range(max_retries):
		try:
			result = tavily.search(query=f'site:{domain}')
			if not result or 'results' not in result:
				raise Exception('No results')
			return len(result['results']) > 0
		except Exception as e:
			# Errors raised by the Tavily client land here (presumably including
			# rate limits), so wait a little longer on each attempt.
			print(f'Error: {e}. Retrying')
			time.sleep(2 ** attempt)
	return None

def _print_marker(marker):
	'Write a progress marker without a newline and flush it immediately.'
	print(marker, end='', flush=True)

def print_progress():
	'Print a dot to show that one more item was processed.'
	_print_marker('.')

def print_error():
	'Print an exclamation mark to signal an error.'
	_print_marker('!')

#### Unsafe text

In [28]:
def flag_unsafe_text(flagged_df):
	'''Fill the `text_unsafe` column for rows that have not been checked yet (in place).

	Each row with NA in `text_unsafe` is sent to `is_unsafe`. The column receives
	`safe_flag`, the Llama Guard category string, or True. Rows whose check failed
	(is_unsafe returned None) are left NA so a later run retries them. The
	DataFrame passed in is saved if any row was attempted.
	'''
	if 'text_unsafe' not in flagged_df.columns:
		flagged_df['text_unsafe'] = None
	# object dtype so strings can be stored even if the column was read back as all-NaN floats
	flagged_df['text_unsafe'] = flagged_df['text_unsafe'].astype(object)

	# Identify rows with empty 'text_unsafe' value
	pending = flagged_df.index[flagged_df['text_unsafe'].isna()]
	for i in pending:
		result = is_unsafe(flagged_df.loc[i, 'text'])
		if result is None:
			print_error()  # unknown verdict: leave NA rather than marking it safe
			continue
		flagged_df.loc[i, 'text_unsafe'] = result or safe_flag  # False -> safe_flag
		print_progress()

	if not pending.empty:
		save_flagged_df(flagged_df)

# Only rows whose `text_unsafe` is still empty are checked, so re-running this
# cell checks only the rows that are still unchecked. (Pre-filling NA with
# safe_flag here would fail on a fresh run and would skip every check.)
flag_unsafe_text(flagged_df)
print(f'Flagged data size: {flagged_df.shape}')
flagged_df.head(2)

Flagged data size: (100, 7)


Unnamed: 0,text,id,url,text_unsafe,domain_unsafe,domain_unindexed,flags
0,We want to know how to best serve you. Please ...,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback,safe,,,safe
1,Architectural Control Committee Policies and F...,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...,safe,True,,safe


#### Unsafe domains

In [29]:
# use tavily api to find whether google indexes the page or not

def get_main_domain(domain):
	'''Return the main domain for a host name that may include subdomains.

	Examples: news.example.com -> example.com,
	www.example.co.uk -> example.co.uk, www.example.ac.ir -> example.ac.ir

	Heuristic (no public-suffix list): keep three labels when the second-level
	label is 'co' or 'ac' or the TLD has two letters, otherwise keep two. For
	example, news.example.de is kept whole. The host is lower-cased and a
	trailing root dot is ignored.
	'''
	parts = domain.lower().rstrip('.').split('.')
	# Handle cases with ccTLDs or multi-part TLDs (like co.uk, ac.ir)
	if len(parts) >= 3 and (parts[-2] in ('co', 'ac') or len(parts[-1]) == 2):
		return '.'.join(parts[-3:])
	return '.'.join(parts[-2:])


def _url_main_domain(url):
	'Return get_main_domain() of the URL host, or None when the URL is missing or has no host.'
	if not isinstance(url, str):
		return None
	host = urlparse(url).hostname  # host only: no port or credentials, lower-cased
	return get_main_domain(host) if host else None


def flag_unsafe_domains(flagged_df):
	'''Add `domain_unsafe` and `domain_unindexed` columns of True/False (in place).

	Each row's URL is reduced to its main domain. Every distinct domain is checked
	once with `is_unsafe` and once with `is_indexed`, and rows are flagged by domain.
	Unsafe domains are assumed to be less likely to be indexed by a search engine,
	so unindexed domains are flagged as well. A column that already exists is left
	untouched, so a re-run skips finished checks. Rows without a usable URL get
	False in both columns.
	'''
	row_domains = flagged_df['url'].map(_url_main_domain)
	# Rows are matched on the main domain, the same key that was checked, so a row
	# from news.example.com picks up the verdict for example.com.
	domains = sorted(set(row_domains.dropna()))

	if 'domain_unsafe' not in flagged_df.columns:
		unsafe_domains = set()
		for domain in domains:
			if is_unsafe(domain):  # a failed check (None) counts as safe
				unsafe_domains.add(domain)
			print_progress()
		print(f'Unsafe domains: {unsafe_domains}')
		flagged_df['domain_unsafe'] = row_domains.isin(unsafe_domains).astype(object)
		save_flagged_df(flagged_df)

	if 'domain_unindexed' not in flagged_df.columns:
		unindexed_domains = set()
		for domain in domains:
			if not is_indexed(domain):  # a failed lookup (None) counts as unindexed
				unindexed_domains.add(domain)
			print_progress()
		print(f'Unindexed domains: {unindexed_domains}')
		flagged_df['domain_unindexed'] = row_domains.isin(unindexed_domains).astype(object)
		save_flagged_df(flagged_df)


flag_unsafe_domains(flagged_df)  # adds domain_unsafe / domain_unindexed when missing
print('Flagged data size: ' + str(flagged_df.shape))
flagged_df.head(n=2)

Flagged data size: (100, 7)


Unnamed: 0,text,id,url,text_unsafe,domain_unsafe,domain_unindexed,flags
0,We want to know how to best serve you. Please ...,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback,safe,,,safe
1,Architectural Control Committee Policies and F...,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...,safe,True,,safe


#### Sensitive Topics and Biased data

In [30]:
groq_client = Groq()

# Labels the LLM may put in the `flags` column.
# Currently disabled: 'cheating_service', 'unethical'.
flags_list = [
	'sensitive_topic',
	'biased',
	'religious',
	'lottery',
	'scam',
	'advertisement',
	'adversarial_attack_data_poisoning',
]
flags_to_detect = ', '.join(flags_list)

def ask_bot(df, indices_list, max_retries=3):
	'''Ask the Groq LLM to flag a batch of rows and return its answer as a DataFrame.

	Args:
		df: DataFrame with `id` and `text` columns.
		indices_list: index *labels* of the rows to send (e.g. a slice of `df.index`).
		max_retries: number of attempts before giving up.

	Returns:
		DataFrame with columns `id` and `flags` parsed from the model's CSV reply,
		or None if the reply parsed to an empty table.

	Raises:
		Exception: if no attempt produced a parsable reply.
	'''
	# `.loc`, not `.iloc`: callers pass index labels, which differ from positions
	# once rows have been dropped (e.g. by the language or empty-text filters).
	csv_text = df.loc[indices_list, ['id', 'text']].to_csv(index=False)
	prompt = (
		'You are a content moderator. \n '
		f'Fill `flags` column that contains one or more of flags to detect: `{flags_to_detect}` '
		'in a single string separated by commas, or leave it blank. '
		'Return csv file back in triple backticks. \n'
		f'\n\n ```\n{csv_text}\n``` \n\n'
		'Output columns (to be strictly followed): `id`,`flags` \n'
		'Return only csv response and no other text'
	)
	for _ in range(max_retries):
		try:
			chat_completion = groq_client.chat.completions.create(
				messages=[{ 'role': 'user', 'content': prompt }],
				model=os.getenv('GROQ_MODEL'),
			)
			response = chat_completion.choices[0].message.content
			# replace single backticks with triple backticks
			if '```' not in response:
				response = response.replace('`', '```')
			response = response.replace('```\n```', '```')
			# get the value from triple backticks
			response = response.split('```')[1].strip()

			if response.startswith('csv'):  # remove the 'csv' language tag
				response = response[3:].strip()
			answer_df = pd.read_csv(io.StringIO(response))

			answer_df = answer_df[['id', 'flags']]  # use only these columns
			if answer_df.empty:
				return None
			return answer_df
		except Exception:
			print_error()
			time.sleep(1)  # brief pause before retrying the API call
	raise Exception('Bot response failed')


def _normalize_flags(value):
	'Keep only known flags from a comma-separated answer; return safe_flag if none remain.'
	if not isinstance(value, str):
		return safe_flag
	known = [flag.strip() for flag in value.split(',') if flag.strip() in flags_list]
	# dict.fromkeys removes duplicates while keeping the original order
	return ', '.join(dict.fromkeys(known)) or safe_flag


def flag_with_LLM(flagged_df):
	'''Fill the `flags` column for rows without flags, 4 rows per LLM request (in place).

	Rows the model leaves blank or omits become `safe_flag`. Afterwards every value
	is normalised so it contains only entries from `flags_list`. Several flags stay
	comma-separated instead of being discarded. The DataFrame passed in is then saved.
	'''
	if 'flags' not in flagged_df.columns:
		flagged_df['flags'] = None
	# object dtype so strings can be stored even if the column was read back as all-NaN floats
	flagged_df['flags'] = flagged_df['flags'].astype(object)

	noflags_indices = flagged_df.index[flagged_df['flags'].isna()]

	# split indices into chunks
	chunk_size = 4
	chunks = [noflags_indices[i:i + chunk_size] for i in range(0, len(noflags_indices), chunk_size)]
	print(f'Total chunks: {len(chunks)}')

	for chunk in chunks:
		flags_df = ask_bot(flagged_df, indices_list=chunk, max_retries=5)
		# every row in the chunk defaults to safe; the model's answers override below
		flagged_df.loc[chunk, 'flags'] = safe_flag
		if flags_df is not None:  # None: the model returned an empty table
			answer_flags = flags_df['flags'].fillna(safe_flag)
			in_chunk = flagged_df.index.isin(chunk)
			for row_id, flags in zip(flags_df['id'], answer_flags):
				# restrict to this chunk so an unexpected id cannot overwrite other rows
				flagged_df.loc[in_chunk & (flagged_df['id'] == row_id), 'flags'] = flags
		print_progress()

	flagged_df['flags'] = flagged_df['flags'].map(_normalize_flags)
	save_flagged_df(flagged_df)

flag_with_LLM(flagged_df)  # fills the 'flags' column in place
flagged_df.head(n=2)

Total chunks: 0


Unnamed: 0,text,id,url,text_unsafe,domain_unsafe,domain_unindexed,flags
0,We want to know how to best serve you. Please ...,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback,safe,,,safe
1,Architectural Control Committee Policies and F...,<urn:uuid:77695799-0774-42a1-8eaa-5efbe154c4e0>,http://aberdeencreekfl.com/ACCBusiness/Procedu...,safe,True,,safe


### Filtering using flagged data

In [31]:
filtered_df = flagged_df.copy()

# Flag columns: a row is dropped when any of them still holds a value after normalisation.
columns_to_remove = ['text_unsafe', 'domain_unsafe', 'domain_unindexed', 'flags']

# Normalise flag values: 'safe' and False (bool or string) become None, 'True'/'true'
# become True. Anything else (a Llama Guard category like 'S5', a flag name) is kept.
filtered_df[columns_to_remove] = filtered_df[columns_to_remove].replace({
	'False': None, 'false': None, False: None,
	'True': True, 'true': True,
	'safe': None, 
})

# Llama Guard 2 hazard categories, used to make the text_unsafe counts readable
harm_categories = {
	'S1': 'Violent Crimes',
	'S2': 'Non-Violent Crimes',
	'S3': 'Sex-Related Crimes',
	'S4': 'Child Sexual Exploitation',
	'S5': 'Specialized Advice',
	'S6': 'Privacy',
	'S7': 'Intellectual Property',
	'S8': 'Indiscriminate Weapons',
	'S9': 'Hate',
	'S10': 'Suicide & Self-Harm',
	'S11': 'Sexual Content'
}


def _count_labels(series):
	'Count comma-separated labels in a column; a cell with several labels counts once per label.'
	return series.dropna().astype(str).str.split(',').explode().str.strip().value_counts()


unsafe_count = _count_labels(filtered_df['text_unsafe'])
unsafe_count_transformed = unsafe_count.rename(index=harm_categories)
print(unsafe_count_transformed)
print('')

flags_count = _count_labels(filtered_df['flags'])
print(flags_count)
print('')

# Filter column by column; each count covers only rows not already removed by an earlier column.
removal_reason_data = {}  # e.g. {'text_unsafe': 10, ...}
for column in columns_to_remove:
	flagged_mask = filtered_df[column].notna()
	removal_reason_data[column] = int(flagged_mask.sum())
	print(f'{column}: {removal_reason_data[column]}')
	filtered_df = filtered_df[~flagged_mask]
removed_rows = sum(removal_reason_data.values())

filtered_df = filtered_df.drop(columns=columns_to_remove).reset_index(drop=True)
print(f'Removed rows: {removed_rows} of {flagged_df.shape[0]}')
print(f'Retained rows: {filtered_df.shape[0]}')

filtered_df.to_csv(filtered_data_filename, index=False)
filtered_df.head(2)

text_unsafe
Specialized Advice    4
Sexual Content        3
Sex-Related Crimes    1
Privacy               1
Non-Violent Crimes    1
Name: count, dtype: int64

flags
sensitive_topic                      6
advertisement                        6
adversarial_attack_data_poisoning    3
biased                               2
lottery                              1
scam                                 1
Name: count, dtype: int64

text_unsafe: 8
domain_unsafe: 3
domain_unindexed: 5
flags: 16
Removed rows: 32 of 100
Retained rows: 68


Unnamed: 0,text,id,url
0,We want to know how to best serve you. Please ...,<urn:uuid:faff9b64-041c-4b98-8be4-7ff2a02e4b8d>,http://38.paulosimoes.net/forms/feedback
1,Welcome to AnnieMation’s webpage. We are an in...,<urn:uuid:76d0f406-290e-41c9-a4bc-2062e6fc6296>,http://anniemationclog.co.uk/


### _Optional_: Optimize text for fine-tuning

In [32]:
columns_to_keep = ['text', 'id']

# Resume from an existing shortened file: keep rows still present in filtered_df,
# append rows that are new, and leave `finetune_text` empty for the new ones.
if os.path.exists(short_text_filename):
	short_text_df = pd.read_csv(short_text_filename)
	new_filtered_df = filtered_df[columns_to_keep].copy()

	# keep rows that are in filtered_df and remove others
	short_text_df = short_text_df[short_text_df['id'].isin(new_filtered_df['id'])]
	# add missing rows from filtered_df
	missing_rows = new_filtered_df[~new_filtered_df['id'].isin(short_text_df['id'])]
	if not missing_rows.empty:
		short_text_df = pd.concat([short_text_df, missing_rows])
		short_text_df = short_text_df.drop_duplicates(subset='id')
	# Fresh 0..n-1 index: the two sources carry overlapping index labels, and
	# label-based updates (df.loc[i, ...]) on duplicate labels hit several rows.
	short_text_df = short_text_df.reset_index(drop=True)
	if 'finetune_text' not in short_text_df.columns:
		short_text_df['finetune_text'] = None
else:
	short_text_df = filtered_df[columns_to_keep].copy()
	short_text_df['finetune_text'] = None

# take each row and ask groq to shorten the text and make it suitable for fine-tuning dataset
shortener_prompt_template = (
	'You are a content moderator who is preparing a dataset for fine-tuning a language model. '
	'You have a text that needs to be shortened and made suitable for the dataset. \n'
	'Return the optimized text in the triple backticks. \n'
	'\n\n Initial text: ```\n{initial_text}\n``` \n\n'
)

def get_shorter_text(text, max_retries=3):
	'''Ask the Groq model for a shortened, fine-tuning-ready version of `text`.

	The answer is expected inside triple backticks. Returns the extracted text,
	or None when every attempt fails (each failure is printed, followed by a 1 s pause).
	'''
	attempt = 0
	while attempt < max_retries:
		attempt += 1
		try:
			prompt = shortener_prompt_template.format(initial_text=text)
			reply = groq_client.chat.completions.create(
				messages=[{ 'role': 'user', 'content': prompt }],
				model=os.getenv('GROQ_MODEL'),
			).choices[0].message.content
			# a reply that only uses single backticks is upgraded to triple backticks
			if '```' not in reply:
				reply = reply.replace('`', '```')
			fenced = reply.replace('```\n```', '```').split('```')[1].strip()
			if not fenced:
				raise Exception('Empty response')
			return fenced
		except Exception as e:
			print(f'Error: {e}. Retrying')
			time.sleep(1)
	return None

def shorten_text_df(df):
	'''Fill `finetune_text` with LLM-shortened text for rows where it is missing (in place).

	A reply that is missing or longer than the original text is rejected (printed
	as '!'), and the row stays empty so a later run can retry it. Prints how many
	characters were saved overall and returns `df`.
	'''
	if 'finetune_text' not in df.columns:
		df['finetune_text'] = None
	# object dtype so strings can be stored even if the column was read back as all-NaN floats
	df['finetune_text'] = df['finetune_text'].astype(object)
	target_col = df.columns.get_loc('finetune_text')

	initial_length_sum = 0
	shortened_length_sum = 0
	# Write by position, not index label: with duplicate labels (e.g. after a
	# concat) a label-based assignment would overwrite several rows at once.
	rows = zip(df['text'].tolist(), df['finetune_text'].tolist())
	for pos, (text, current) in enumerate(rows):
		if not pd.isna(current):
			continue
		shorter_text = get_shorter_text(text)
		if not shorter_text or len(shorter_text) > len(text):
			print_error()
			continue
		df.iloc[pos, target_col] = shorter_text
		initial_length_sum += len(text)
		shortened_length_sum += len(shorter_text)
		print_progress()

	saved_length = initial_length_sum - shortened_length_sum
	if initial_length_sum and saved_length:
		saved_percent = (saved_length / initial_length_sum) * 100
		print(f'\nReduced: {saved_length}/{initial_length_sum} characters ({saved_percent:.2f}%)')
	return df

short_text_df = shorten_text_df(short_text_df)  # returns the same, updated DataFrame
short_text_df.to_csv(short_text_filename, index=False)
print('Shortened text data size: ' + str(short_text_df.shape))
short_text_df.head(n=2)

Shortened text data size: (68, 3)


Unnamed: 0,text,id,finetune_text
0,Welcome to AnnieMation’s webpage. We are an in...,<urn:uuid:76d0f406-290e-41c9-a4bc-2062e6fc6296>,We want to know how to best serve you. For eme...
1,Dark Side Of The Lens.\nI’m sure this film has...,<urn:uuid:2ceaf741-413a-4618-ac03-071d409b2c14>,Dark Side Of The Lens.\nBeautiful film capturi...
