### Imports

In [6]:
import requests
import pandas as pd
from random_word import Wordnik
import time
import math
import numpy as np
import spacy

  from .autonotebook import tqdm as notebook_tqdm


### Functions to scrape prompts from Lexica API

In [10]:
def generate_search_strings(num_items: int, counter: int):
    '''Generate random words to use as search strings for Lexica.

    Words are requested from Wordnik in batches of 10, so the request is
    rounded up to a whole number of batches. The result can therefore hold a
    few more words than ``num_items``, or fewer when some calls fail.

    Args:
        num_items: Approximate number of search terms wanted.
        counter: Unused; kept so existing callers keep working.

    Returns:
        A list with the words collected from every successful call.
    '''
    batch_size = 10
    num_calls = math.ceil(num_items / batch_size)
    output = []
    if num_calls > 0:
        # Only build the client when at least one request will be made.
        wordnik_service = Wordnik()
        for _ in range(num_calls):
            try:
                # Fetch one batch of random words.
                res = wordnik_service.get_random_words(includePartOfSpeech ="noun,verb,adverb",hasDictionaryDef=True, limit=batch_size)
                #TODO: Check if adding a duplicate search term
                output.extend(res)
            except Exception as e:
                # Best effort: report the failed batch and carry on with the next.
                print(e)
    # Report what was actually collected, not what was requested.
    print('Generated ', len(output), ' search terms.' )
    return output

def lexica_search(terms: list, counter: int):
    '''Search Lexica for each term and collect the prompts it returns.

    Searching starts at index ``counter`` of ``terms`` so an interrupted run
    can be resumed. The index of the next unprocessed term is written to
    ``./counter.txt`` both when a request fails and once every term is done.

    Args:
        terms: Search strings; spaces are replaced by ``+`` before querying.
        counter: Index in ``terms`` to start from.

    Returns:
        ``(prompts, counter)``: a DataFrame indexed by image id with the
        columns ``search_string``, ``source`` and ``prompt``, and the index of
        the next term to process. On a failed request the partial results
        gathered so far are returned.
    '''
    search_base='https://lexica.art/api/v1/search?q='
    prompts = pd.DataFrame(columns=['search_string','source','prompt'])
    print('Starting counter is: ', counter)
    for i in range(counter,len(terms)):
        query = terms[i].replace(' ', '+')
        try:
            d = requests.get(url=(search_base + query))
            data = d.json()
            obj = data['images']
        except Exception as e:
            print(e)
            print('Writing counter to file: ', counter)
            # Save the resume position before backing off, so interrupting
            # the notebook during the sleep does not lose it.
            with open('./counter.txt', 'w') as f:
                f.write(str(counter))
            # Presumably waits out the API rate limit before the caller retries.
            time.sleep(35)
            return prompts, counter
        for item in obj:
            # Keyed by image id, so a repeated id overwrites its earlier row.
            prompts.loc[item['id']] = [query, item['src'], item['prompt']]
        counter +=1
        print('Commited prompts for term ', counter, ' out of ', len(terms))
        time.sleep(.5)
    with open('./counter.txt', 'w') as f:
        # write() needs text; passing the int raised TypeError after a full run.
        f.write(str(counter))
    return prompts, counter

### Creating a database of prompts for training

In [11]:
# Load the search terms (one per line) and the resume position saved by a
# previous run.
with open('./1000-most-common.txt') as file:
    common = [line.rstrip() for line in file]
with open('./counter.txt') as file:
    counter = int(file.readline())

# lexica_search returns early whenever a request fails, handing back the index
# of the term to retry, so keep calling it until every term has been handled.
# After a complete pass the counter equals len(common); comparing against
# len(common) - 1 with != would never stop.
while counter < len(common):
    print("Starting the procedure again with counter: ", counter)
    res_common, counter = lexica_search(terms = common, counter = counter)
    # Each pass is saved separately, named after the position it reached.
    filename = './prompts-with-common-' + str(counter) + '.json'
    res_common.to_json(filename, orient='split')


#res.to_json('./common-df.json',orient='split')
res = pd.read_json('./common-df.json', orient='split')
master = pd.read_json('./master-prompts.json', orient='split')

# Stack the previously collected prompts and the common-word prompts.
full = pd.concat([master,res])
# NOTE(review): this writes CSV to ./full-prompts.txt, while the tokenizing
# cell reads ./full-prompts.json - confirm which file is the intended hand-off.
full.to_csv('./full-prompts.txt')

Starting the procedure again with counter:  927
Starting counter is:  927
Commited prompts for term  928  out of  1000
Commited prompts for term  929  out of  1000
Commited prompts for term  930  out of  1000
Commited prompts for term  931  out of  1000


KeyboardInterrupt: 

In [7]:
# Tokenizing prompts w/spacy
# Load the prompt table that was saved with orient='split'.
full = pd.read_json('./full-prompts.json', orient='split')
# spaCy pipeline whose entities (doc.ents) tokenize() reads below.
nlp = spacy.load('en_core_web_sm')
def tokenize(prompt):
    """Return one record per entity that the global ``nlp`` finds in ``prompt``.

    Each record holds the entity text (``token``), its character offsets
    (``char_start``/``char_end``), an empty ``label``, ``is_weak_label`` set
    to False, and the entity label stored under the key ``pos``.
    """
    return [
        {
            'token': entity.text,
            'char_start': entity.start_char,
            'char_end': entity.end_char,
            'label': None,
            'is_weak_label': False,
            'pos': entity.label_,
        }
        for entity in nlp(prompt).ents
    ]

# Add a 'tokens' column holding the list of entity records tokenize() returns
# for each prompt.
full['tokens'] = full['prompt'].apply(tokenize)

### Importing Annotations

For experimentation purposes, I used the community version of Label Studio to annotate ~230 prompts. The label schema is [ARTIST, OTHER]. The Label Studio ground truth needs to be transformed into a spaCy-compatible format.

In [8]:
#Adding GT for the prompts based off of labels from Label Studio
import json

# Read the Label Studio export; the context manager closes the file handle.
with open('./gt.json') as f:
    gt_file = json.load(f)
#print('Example entry: ', gt_file[58])

# Keep the result of the first annotation set of each task, together with the
# row data the task was created from.
filtered = [
    {'annotations': x['annotations'][0]['result'], 'data': x['data']}
    for x in gt_file
]
#print(filtered[58])

# Index the annotations by the row label stored in the export ('Unnamed: 0')
# so each DataFrame row needs one dict lookup instead of a scan over every
# annotation. If a label occurs more than once, the last entry wins.
gt_by_row = {item['data']['Unnamed: 0']: item for item in filtered}

# Build the whole column in one positional pass. Assigning through
# full['gt_raw'][i] is chained indexing: it can write to a temporary copy and
# mixes the positional counter i with label-based lookup.
# Rows without an annotation get None.
full['gt_raw'] = [gt_by_row.get(row_label) for row_label in full.index]

In [13]:
# NOTE: everything between the triple quotes is a string literal, so none of
# it executes; the cell only echoes the text back as its output.
# The text refers to a frame named `df`, while the rest of the notebook calls
# the frame `full` - presumably the same frame; confirm before re-enabling.
# NOTE(review): a later cell reads ./trim-df.json, which this disabled code
# appears to be the only place to write.
'''
#count = 0
#for i in range(df.shape[0]):
    #if df['gt_raw'][i] is not None:
        #for item in df['gt_raw'][i]['annotations']:
            #print(item['value'])

trim = df[df['gt_raw'].notnull()]
trim.shape
trim.to_json('./trim-df.json', orient='split')
'''

"\n#count = 0\n#for i in range(df.shape[0]):\n    #if df['gt_raw'][i] is not None:\n        #for item in df['gt_raw'][i]['annotations']:\n            #print(item['value'])\n\ntrim = df[df['gt_raw'].notnull()]\ntrim.shape\ntrim.to_json('./trim-df.json', orient='split')\n"

### Weak Labeling with BART Large MNLI

Weak supervision is a helpful technique when working with few or no labeled examples. Here, I demonstrate using the BART LLM (BART Large MNLI, run as a zero-shot classifier) as a source of weak signal for labeling. For each entity that spaCy labels as "PERSON", ask BART whether this person is an artist or not. If prob(artist) > 0.80, weakly label the example as ARTIST.

In [188]:

from transformers import pipeline
# Zero-shot classification pipeline backed by facebook/bart-large-mnli; it is
# called below with an entity string and the candidate labels.
classifier = pipeline("zero-shot-classification",
                      model="facebook/bart-large-mnli")

In [199]:
labels = ['artist', 'other']
threshold = 0.80


def weak_label_artists(tokens):
    '''Weakly label unlabelled PERSON entities as 'artist', in place.

    For every record in ``tokens`` that has no label yet and whose ``pos``
    is 'PERSON', the zero-shot ``classifier`` scores the entity text against
    ``labels``. When the 'artist' score exceeds ``threshold`` the record gets
    ``label='artist'`` and ``is_weak_label=True``.

    Args:
        tokens: List of entity records as produced by ``tokenize``.
    '''
    for item in tokens:
        if item['label'] is None and item['pos'] == 'PERSON':
            result = classifier(item['token'], labels)
            # The pipeline returns labels sorted by score, so scores[0]
            # belongs to whichever label ranked first. Look the 'artist'
            # score up by name instead of assuming it is first.
            artist_score = dict(zip(result['labels'], result['scores']))['artist']
            print(result['sequence'],' : ', artist_score)
            if artist_score > threshold:
                item['label'] = 'artist'
                item['is_weak_label'] = True


# Spot-check a single prompt first.
temp = full.iloc[222]['tokens']
print(temp)
weak_label_artists(temp)
# A statement that appended the item to an undefined `annotation` record used
# to sit here; it raised NameError on a fresh kernel and was removed.

#Omitting adding these weak labels to the training set for now, since the model was able to get strong scores to start.
for row_tokens in full['tokens']:
    weak_label_artists(row_tokens)

#full.to_json('./full-checkpoint.json', orient = 'split')

[{'token': 'james jean', 'char_start': 162, 'char_end': 172, 'label': 'artist', 'is_weak_label': True, 'pos': 'PERSON'}]


### Prepare data for training
spaCy needs data in its `Doc` object form. In this section, we trim the DataFrame to only the strongly labeled examples, convert it into docs, and write them to disk.

[{'value': {'start': 168, 'end': 192, 'text': 'cornelis van poelenburgh', 'labels': ['Artist']}, 'id': '-PCrj-bI4Z', 'from_name': 'label', 'to_name': 'text', 'type': 'labels', 'origin': 'manual'}, {'value': {'start': 197, 'end': 208, 'text': 'dosso dossi', 'labels': ['Artist']}, 'id': 'y01Q-cujWT', 'from_name': 'label', 'to_name': 'text', 'type': 'labels', 'origin': 'manual'}]


In [42]:
trim = pd.read_json('./trim-df.json', orient='split')
#print(trim['prompt'][0])
#print(trim['gt_raw'][0]['annotations'])

# Target shape of the training data, one (text, spans) tuple per prompt:
# [("a special operations member that looks like colin farrell and brad pitt,
#    ... oil reinassance painting by cornelis van poelenburgh and dosso dossi,
#    ultra detailed, concept art, 8 k what",
#   [(168, 192, 'Artist'), (197, 208, 'Artist')]), ...]

# Need to add 'other' annotations as assumed negative label
# Rows are walked positionally (zip over the columns) so the result does not
# depend on what kind of index the saved frame carries.
for tokens, gt in zip(trim['tokens'], trim['gt_raw']):
    existing_annotations = gt['annotations']
    for item in tokens:
        # Treat any character overlap as "already annotated", not just a
        # shared start or end offset. Overlapping spans cannot both be
        # assigned to doc.ents, so they must not end up in the training data.
        overlaps_existing = any(
            item['char_start'] < check['value']['end']
            and check['value']['start'] < item['char_end']
            for check in existing_annotations
        )
        if not overlaps_existing:
            # add token with 'other' label
            existing_annotations.append({'value': {'start': item['char_start'], 'end': item['char_end'], 'labels':['Other'] }})

# Flatten each row to (prompt, [(start, end, label), ...]), keeping only the
# first label of every annotation.
training_data = [
    (
        prompt,
        [
            (anno['value']['start'], anno['value']['end'], anno['value']['labels'][0])
            for anno in gt['annotations']
        ],
    )
    for prompt, gt in zip(trim['prompt'], trim['gt_raw'])
]
print(training_data[22:30])
print(len(training_data))

[('a d & d character portrait of a beautiful noble elf princess with blonde hair, regal jewellry and elegant dress by bowater, charlie ', [(115, 122, 'Artist'), (124, 131, 'Artist')]), ('nicki minaj hugged by barack obama from behind, soviet colored propaganda poster, highly detailed illustration ', [(0, 11, 'Other'), (48, 54, 'Other')]), ('a beautiful painting of an indigenous man blowing tobacco snuff into the nose of another man , fantasy art, matte painting, highly detailed', []), ('File sharing website design', []), ('rage comics meme from the year 2 0 3 0. ', [(22, 38, 'Other')]), ('Arsenal win the Premier League', []), ('traditional japanese tiger drawing by junji ito, ', [(38, 47, 'Artist'), (12, 20, 'Other')]), ('foundation with gold, silver, precious stones, wood, hay, straw', [])]
230


In [43]:
# Convert dataframe to Docs for training
import spacy
from spacy.tokens import DocBin
# NOTE: this rebinds the global `nlp` to a blank English pipeline, so
# tokenize() will find no entities if it is re-run after this cell.
nlp = spacy.blank("en")
db = DocBin()

# Expected input shape, for example:
# training_data = [
#   ("Tokyo Tower is 333m tall.", [(0, 11, "BUILDING")]),
# ]

skipped_spans = 0
for text, annotations in training_data:
    doc = nlp(text)
    ents = []
    for start, end, label in annotations:
        span = doc.char_span(start, end, label=label)
        if span is None:
            # char_span returns None when the offsets do not line up with
            # token boundaries. Count these so the loss is visible.
            skipped_spans += 1
            continue
        ents.append(span)
    doc.ents = ents

    db.add(doc)
print('Skipped ', skipped_spans, ' annotations that do not align with token boundaries.')
db.to_disk("./train.spacy")

### Training an NER model
Model training is managed via spacy config files (`prompt_config.cfg`) and the command line. See **training-pipeline-output.txt** for checkpoints, loss, and overall score.


## Section 2: Evaluation with a new dataset

In [None]:
from datasets import load_dataset
import spacy
from spacy.tokens import DocBin
# Pipeline used in the next cell to turn the evaluation prompts into Docs.
nlp = spacy.load("en_core_web_sm")

# 'train' split of the Gustavosta/Stable-Diffusion-Prompts dataset; the next
# cell reads its 'Prompt' column.
dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts", split = 'train')

In [None]:
# Take the first 100 prompts of the dataset as the evaluation set.
eval_set = list(dataset['Prompt'][0:100])

# Run each prompt through the pipeline and serialise the resulting Docs.
db = DocBin()
for prompt_text in eval_set:
    db.add(nlp(prompt_text))
db.to_disk('./eval-docs.spacy')

### Conclusion

After running `spacy-eval.sh`, the predictions are stored in the `eval` directory. Traditionally, I would label a GT eval set by hand. You'll see that model predictions for the holdout set are anecdotally highly accurate. For an official score, I would generally look at some doc- and token-level metrics, plus things like a confusion matrix, class-level errors, and a precision/recall curve to determine how and where I would fine-tune.

Some novel applications of this model:
* Prompt optimization: Artist names are analyzed against prompt outputs to determine optimal artist names and locations in a prompt
* A model like this begins to attribute "credit" to the various artists whose work was referenced in the prompt. Attribution for artists named in prompts is an unresolved issue; a model like this enables the conversation to continue.