# Function calling dataset

This notebook creates a new dataset, based on a sample of Exam_v3, for training an LLM agent that sends function calls when given a dialogue.

The agent will track the conversation and respond with an appropriate function call given the patient and Assistant responses.

Its tag will be `[Phoropter]`.

Function names:
- pd_adjust(side:str, increment:int)
  - adjusts the PD by a given increment
- occlude(side:str)
  - covers an eye
- unocclude(side:str)
  - uncovers an eye
- display_vachart(type:str, rows:list)
  - displays a 3-row visual acuity chart
- resize_chart(direction:['enlarge','minimize'])
  - changes rows on visual acuity chart
- accuracy(letters:str)
  - identifies the most likely chart row and checks the patient's response for accuracy
- loadRx()
  - loads previous prescription
- recordAcuity(side:str, acuity)
  - stores eye's visual acuity
- loadAR()
  - loads autorefractor prescription
- add_sphere(power:int)
  - changes spherical lens power
- add_cylinder(power:int)
  - changes cylindrical lens power
- display_dots()
  - displays astigmatic dots
- patient_choice(choice:str (['1','2','same']))
  - patient's selection 
- display_duochrome()
  - displays sloan chart with red side, green side
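
The list above can be turned into a small lookup table for checking generated calls. The sketch below is only illustrative: it assumes each call is encoded as a dict with a `function` key (the encoding used later in this notebook), treats `{'function': None}` as "no action", and the names `PHOROPTER_FUNCTIONS` and `is_valid_call` are hypothetical.

```python
# Hypothetical schema: function name -> argument names taken from the list above.
PHOROPTER_FUNCTIONS = {
    "pd_adjust": {"side", "increment"},
    "occlude": {"side"},
    "unocclude": {"side"},
    "display_vachart": {"type", "rows"},
    "resize_chart": {"direction"},
    "accuracy": {"letters"},
    "loadRx": set(),
    "recordAcuity": {"side", "acuity"},
    "loadAR": set(),
    "add_sphere": {"power"},
    "add_cylinder": {"power"},
    "display_dots": set(),
    "patient_choice": {"choice"},
    "display_duochrome": set(),
}


def is_valid_call(call):
    """Return True if the call names a known function and uses only listed arguments."""
    name = call.get("function")
    if name is None:
        # Assumption: a None function means "no action" and is always acceptable.
        return True
    if name not in PHOROPTER_FUNCTIONS:
        return False
    # Reject arguments that are not listed for this function.
    args = set(call) - {"function"}
    return args <= PHOROPTER_FUNCTIONS[name]
```

A check like this would flag `lens_compare`, which appears in later templates but is not in the list above.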

In [1]:
from vision_dataset import VisionDatasetCreator, VisionDataset

# avg percentages of exam phase lengths
gs_len=0.14
sp_len=0.25
ac_len=0.50
cv_len=0.11
# percentage of samples taken from each exam phase
gs_ratio=0.15
sp_ratio=0.30
ac_ratio=0.40
cv_ratio=0.15
# minimum dialogue lengths
gs_min = 2
sp_min = 4
ac_min = 8
cv_min = 6
# maximum dialogue lengths
gs_max = 5
sp_max = 14
ac_max = 18
cv_max = 10

sampling_strategy = dict(
    gs=[gs_len, gs_ratio, gs_min, gs_max],
    sp=[sp_len, sp_ratio, sp_min, sp_max],
    ac=[ac_len, ac_ratio, ac_min, ac_max],
    cv=[cv_len, cv_ratio, cv_min, cv_max]
)

data_dir = '/data/datasets/Exam_v3/'
# set seed to get randomization with reproducible results
dataset_creator = VisionDatasetCreator(sampling_strategy, seed=42)

# number of samples drawn from each file
samples = 3
# 3 samples from each training file, 32 files total
size = (32*samples)
dataset_creator.load(data_dir, 'train', size)
# 3 samples from each validation file, 5 files total
size = (5*samples)
dataset_creator.load(data_dir, 'val', size)
# 3 samples from each test file, 8 files total
size = (8*samples)
dataset_creator.load(data_dir, 'test', size)

for i in ['train', 'val', 'test']:
    print('\n', i, len(dataset_creator.dataset[i]))

32 files found. Sampling 3 times per file.
sampling file: /data/datasets/Exam_v3/train/000000.txt
sampling file: /data/datasets/Exam_v3/train/000001.txt
sampling file: /data/datasets/Exam_v3/train/000002.txt
sampling file: /data/datasets/Exam_v3/train/000003.txt
sampling file: /data/datasets/Exam_v3/train/000004.txt
sampling file: /data/datasets/Exam_v3/train/000005.txt
sampling file: /data/datasets/Exam_v3/train/000006.txt
sampling file: /data/datasets/Exam_v3/train/000007.txt
sampling file: /data/datasets/Exam_v3/train/000008.txt
sampling file: /data/datasets/Exam_v3/train/000009.txt
sampling file: /data/datasets/Exam_v3/train/000010.txt
sampling file: /data/datasets/Exam_v3/train/000011.txt
sampling file: /data/datasets/Exam_v3/train/000012.txt
sampling file: /data/datasets/Exam_v3/train/000013.txt
sampling file: /data/datasets/Exam_v3/train/000014.txt
sampling file: /data/datasets/Exam_v3/train/000015.txt
sampling file: /data/datasets/Exam_v3/train/000016.txt
sampling file: /data/d
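
Before sampling, it can help to sanity-check the `sampling_strategy` dictionary. The sketch below assumes (this is not confirmed by `VisionDatasetCreator`) that the phase lengths and the sample ratios are each meant to sum to 1 and that every minimum should not exceed its maximum. `check_sampling_strategy` is a hypothetical helper.

```python
import math


def check_sampling_strategy(strategy):
    """Return a list of problems in a {phase: [length, ratio, min, max]} dict."""
    problems = []
    total_len = sum(values[0] for values in strategy.values())
    total_ratio = sum(values[1] for values in strategy.values())
    # Assumption: both columns are fractions of a whole and should sum to 1.
    if not math.isclose(total_len, 1.0, abs_tol=1e-9):
        problems.append(f"phase lengths sum to {total_len}, expected 1.0")
    if not math.isclose(total_ratio, 1.0, abs_tol=1e-9):
        problems.append(f"sample ratios sum to {total_ratio}, expected 1.0")
    for phase, (_, _, min_len, max_len) in strategy.items():
        if min_len > max_len:
            problems.append(f"{phase}: min {min_len} is greater than max {max_len}")
    return problems


strategy = dict(
    gs=[0.14, 0.15, 2, 5],
    sp=[0.25, 0.30, 4, 14],
    ac=[0.50, 0.40, 8, 18],
    cv=[0.11, 0.15, 6, 10],
)
print(check_sampling_strategy(strategy))
```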

In [2]:
train = VisionDataset(dataset_creator.dataset["train"])
val = VisionDataset(dataset_creator.dataset["val"])
test = VisionDataset(dataset_creator.dataset["test"])

In [3]:
train.data[0]

{'dialogue': [{'role': 'LocalTech',
   'content': "Alright, and I'm gonna actually, let him look at me, my nose specifically."},
  {'role': 'LocalTech', 'content': 'And can her right PD come in, please?'}],
 'response': [{'role': 'Assistant',
   'content': 'Adjusting the right PD. Is that better?'}]}

## New dataset class - PhoropterDataset

In [4]:
def vis2phor(data):
    reformatted = dict(dialogue=list(),
                       response=list())
    for d in data["dialogue"]:
        reformatted['dialogue'].append(d)
    reformatted['response'].append({'role':'Phoropter', 'content':''})
    return reformatted

reformatted = vis2phor(train.data[0])

def dic2txt(datadict):
    # build one "[role]  content" line per turn, dialogue first, then response
    lines = []
    for key in datadict.keys():
        for d in datadict[key]:
            lines.append(f'''[{d['role']}]  {d['content']}\n''')
    return ''.join(lines)

datum = dic2txt(reformatted)
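
The text format produced by `dic2txt` (a role tag in brackets, two spaces, then the content) can be read back with a small parser. This is a sketch under that assumption about the format; `txt2turns` is a hypothetical name and is not part of `vision_dataset`.

```python
import re

# One turn per line: "[Role]  content" (tag, two spaces, content; content may be empty).
LINE = re.compile(r"^\[(?P<role>\w+)\]  (?P<content>.*)$")


def txt2turns(txt):
    """Parse the flat text format back into a list of role/content dicts."""
    turns = []
    for line in txt.splitlines():
        match = LINE.match(line)
        if match:
            turns.append({"role": match.group("role"),
                          "content": match.group("content")})
    return turns


example = ("[LocalTech]  And can her right PD come in, please?\n"
           "[Phoropter]  \n")
print(txt2turns(example))
```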


In [5]:
cnt = 0
data_dir = '/data/datasets/phoropter_v1/'
dset = 'train/'
fname = f'{cnt}.txt'.zfill(10)
fpath = f'{data_dir}{dset}{fname}'

def write_file(txt, fpath):
    with open(fpath, 'w') as file:
        file.write(txt)

write_file(datum, fpath = f'{data_dir}{dset}{fname}')

In [6]:
datum

"[LocalTech]  Alright, and I'm gonna actually, let him look at me, my nose specifically.\n[LocalTech]  And can her right PD come in, please?\n[Phoropter]  \n"

In [7]:
def write_dataset(data,data_root, split):
    for idx in range(len(data)):
        reformatted = vis2phor(data[idx])
        datum = dic2txt(reformatted)
        fname = f'{idx}.txt'.zfill(10)
        fpath = f'{data_root}{split}{fname}'
        write_file(datum, fpath)

write_dataset(train.data, data_dir, 'train/')
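
The file names above rely on calling `.zfill(10)` on the whole name, which pads the index to six digits only because `.txt` is four characters long. A more explicit variant, sketched here with hypothetical helpers `sample_path` and `write_samples`, formats the index directly and creates the split directory if it is missing.

```python
from pathlib import Path


def sample_path(root, split, idx, width=6):
    """Build <root>/<split>/<zero-padded idx>.txt, e.g. train/000007.txt."""
    return Path(root) / split / f"{idx:0{width}d}.txt"


def write_samples(texts, root, split):
    """Write each text to its own numbered file and return the paths."""
    out_dir = Path(root) / split
    out_dir.mkdir(parents=True, exist_ok=True)
    paths = []
    for idx, txt in enumerate(texts):
        path = sample_path(root, split, idx)
        path.write_text(txt)
        paths.append(path)
    return paths
```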

## Modify existing Exam_v3

I'm going to add a phoropter reaction to each line of text in Exam_v3.

In [1]:
from vision_dataset import VisionDatasetCreator, VisionDataset
from glob import glob
# avg percentages of exam phase lengths
gs_len=0.14
sp_len=0.25
ac_len=0.50
cv_len=0.11
# percentage of samples taken from each exam phase
gs_ratio=0.15
sp_ratio=0.30
ac_ratio=0.40
cv_ratio=0.15
# minimum dialogue lengths
gs_min = 2
sp_min = 2
ac_min = 2
cv_min = 2
# maximum dialogue lengths
gs_max = 5
sp_max = 5
ac_max = 5
cv_max = 5

sampling_strategy = dict(
    gs=[gs_len, gs_ratio, gs_min, gs_max],
    sp=[sp_len, sp_ratio, sp_min, sp_max],
    ac=[ac_len, ac_ratio, ac_min, ac_max],
    cv=[cv_len, cv_ratio, cv_min, cv_max]
)

data_dir = '/data/datasets/Exam_v3/'
# set seed to get randomization with reproducible results
dataset_creator = VisionDatasetCreator(sampling_strategy, seed=42)

trainPaths = glob('/data/datasets/Exam_v3/train/*.txt')

captions = dataset_creator.read_captions(trainPaths[0])

## Automatic captions where possible
I'll try to leverage the standardization I've done to Exam_v3 to speed up the creation of this dataset.

In [2]:
from string import Template

template = Template('''[Phoropter]  [$func]''')
ac_template = Template('''[Phoropter]  [{'function':'accuracy', 'letters':'$letters'}]''')

standard = "[Phoropter]  [{'function':None}]"

patient_choice2 = '''{'function':'patient_choice','choice':'2'}'''
patient_choice1 = '''{'function':'patient_choice','choice':'1'}'''

right_pd_adjust = '''{'function':'pd_adjust','side':'right', 'increment':-1}'''
left_pd_adjust = '''{'function':'pd_adjust','side':'left', 'increment':-1}'''

occlude_left = '''{'function':'occlude','side':'left'}'''
occlude_right = '''{'function':'occlude','side':'right'}'''
lens_compare = '''{'function':'lens_compare'}'''

resize_chart = '''{'function':'resize_chart','enlarge':true}'''
accuracy = '''{'function':'accuracy', 'letters':'ozrsn'}'''
display_vachart = '''{'function':'display_vachart', 'rows':['20/35', '20/25','20/20']}'''
display_duochrome = '''{'function':'display_duochrome'}'''
red = '''{'function':'add_sphere','side':'right','power':-25}'''
green = '''{'function':'add_sphere','side':'right','power':25}'''
dots = '''[Phoropter]  [{'function':'display_dots'},{'function':'lens_compare'}]'''
block_leftandchart = '''[Phoropter]  [{'function':'occlude', 'side':'left'},{'function':'display_vachart', 'rows':['20/35', '20/25','20/20']}]'''
start_exam = '''[Phoropter]  [{'function':'unocclude', 'side':'left'},{'function':'unocclude', 'side':'right'},{'function':'display_vachart', 'rows':['20/35', '20/25','20/20']}]'''
sphere_start = '''[Phoropter]  [{'function':'loadRx'},{'function':'unocclude', 'side':'left'},{'function':'unocclude', 'side':'right'},{'function':'display_vachart', 'rows':['20/35', '20/25','20/20']}]'''
ac_start = '''[Phoropter]  [{'function':'occlude', 'side':'left'},{'function':'unocclude', 'side':'right'},{'function':'display_vachart', 'rows':['20/35', '20/25','20/20']}]'''

l = "[LocalTech]  "
a = "[Assistant]  "
p = "[patient__]  "

def has4capitals(text):
    if len(text) < 42:
        return sum(c.isupper() for c in text) > 3
    else:
        return False
def get_letters(text):
    letters = []
    for c in text:
        if c.isupper():
            letters.append(c)
    letters = [l.lower() for l in letters]
    all_caps = ''.join(letters)
    # first character will be P from Patient
    return all_caps[1:]

captions = dataset_creator.read_captions(trainPaths[1])
def phoropter_response(data):
    new = []
    # iterate over the captions passed in, not the global `captions`
    for cap in data:
        new.append(cap)
        if "guide you though the exam" in cap:
            new.append(start_exam)
            continue
        if "right PD" in cap:
            new.append(template.substitute(func=right_pd_adjust))
            continue
        if "left PD" in cap:
            new.append(template.substitute(func=left_pd_adjust))
            continue
        if cap.startswith(l) and "start" in cap.lower():
            new.append(start_exam)
            continue
        if 'block your left eye' in cap and "the smallest line you can" in cap:
            new.append(block_leftandchart)
            continue
        if 'cover the left eye' in cap and "the smallest line you can" in cap:
            new.append(block_leftandchart)
            continue
        if "the smallest line you can" in cap:
            new.append(template.substitute(func=display_vachart))
            continue
        if "enlarge" in cap:
            new.append(template.substitute(func=resize_chart))
            continue
        if "letters bigger" in cap:
            new.append(template.substitute(func=resize_chart))
            continue
        if 'cover the left eye for you' in cap:
            new.append(template.substitute(func=occlude_left))
            continue
        if 'block your left eye' in cap:
            new.append(template.substitute(func=occlude_left))
            continue
        if 'cover the right eye' in cap:
            new.append(template.substitute(func=occlude_right))
            continue
        if 'block your right eye' in cap:
            new.append(template.substitute(func=occlude_right))
            continue
        if cap.startswith(a + 'Good. Close your eyes for a moment'):
            new.append(sphere_start)
            continue
        if cap.startswith(a + "Good job. Read the smallest line you can with the right eye."):
            new.append(template.substitute(func=occlude_left))
            continue
        if cap.startswith(a + "Good. What's the smallest row you can read with your left eye?"):
            new.append(template.substitute(func=occlude_right))
            continue
        if cap == "[Assistant]  Thank you. Now close your eyes again for me. Are they closed?":
            new.append(ac_start)
            continue
        if cap.lower().startswith(p + "one"):
            new.append(template.substitute(func=patient_choice1))
            continue
        if cap.lower().startswith(p + "two") or cap.lower().startswith(p + "to") or cap.lower().startswith(p + "too"):
            new.append(template.substitute(func=patient_choice2))
            continue
        if cap.lower().startswith(p + "red"):
            new.append(template.substitute(func=red))
            continue
        if cap.lower().startswith(p + "green"):
            new.append(template.substitute(func=green))
            continue
        if "We're going to do the same thing now with dots." in cap:
            new.append(dots)
            continue
        if "Great! Let's compare the dots again." in cap:
            new.append(dots)
            continue
        if "colors in the background." in cap:
            new.append(template.substitute(func=display_duochrome))
            continue
        if "are they" in cap.lower():
            new.append(template.substitute(func=lens_compare))
            continue
        if "dots" in cap:
            new.append(template.substitute(func=lens_compare))
            continue
        if "Which color looks the most clear?" in cap:
            new.append(template.substitute(func=display_duochrome))
            continue
        if has4capitals(cap):
            accuracy = ac_template.substitute(letters=get_letters(cap))
            new.append(accuracy)
            continue
        else:
            new.append(standard)
    return new
new = phoropter_response(captions)
new[40:60]

['[Patient__]  Number two.',
 "[Phoropter]  [{'function':None}]",
 "[Assistant]  We're going to do the same thing now with dots. Are the dots better with one? With two? Or are they the same?",
 "[Phoropter]  [{'function':'display_dots'},{'function':'lens_compare'}]",
 '[Patient__]  to',
 "[Phoropter]  [{'function':'patient_choice','choice':'2'}]",
 '[Assistant]  Okay. Are the dots better with one? With two? Or are they similar?',
 "[Phoropter]  [{'function':'lens_compare'}]",
 '[Patient__]  two.',
 "[Phoropter]  [{'function':'patient_choice','choice':'2'}]",
 '[Assistant]  Again, are the dots better with one, two, or are they similar?',
 "[Phoropter]  [{'function':'lens_compare'}]",
 '[Patient__]  two.',
 "[Phoropter]  [{'function':'patient_choice','choice':'2'}]",
 '[Assistant]  Are they better with one? Or two?',
 "[Phoropter]  [{'function':'lens_compare'}]",
 "[Patient__]  it's the same",
 "[Phoropter]  [{'function':None}]",
 '[Assistant]  Great. Now I will show you choices again, b
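
The cascade of `if` statements in `phoropter_response` could also be written as an ordered table of rules. The sketch below is an alternative layout, not the code used to build the dataset; it covers only four of the rules, and `RULES`, `call`, and `first_match` are hypothetical names.

```python
STANDARD = "[Phoropter]  [{'function':None}]"


def call(payload):
    """Wrap a single function-call payload in the [Phoropter] line format."""
    return f"[Phoropter]  [{payload}]"


# (predicate, response) pairs, checked in order; the first match wins.
RULES = [
    (lambda cap: "right PD" in cap,
     call("{'function':'pd_adjust','side':'right', 'increment':-1}")),
    (lambda cap: "left PD" in cap,
     call("{'function':'pd_adjust','side':'left', 'increment':-1}")),
    (lambda cap: cap.lower().startswith("[patient__]  one"),
     call("{'function':'patient_choice','choice':'1'}")),
    # "to" also covers the transcription "too"
    (lambda cap: cap.lower().startswith(("[patient__]  two", "[patient__]  to")),
     call("{'function':'patient_choice','choice':'2'}")),
]


def first_match(cap, rules=RULES, default=STANDARD):
    """Return the response of the first rule whose predicate accepts the caption."""
    for predicate, response in rules:
        if predicate(cap):
            return response
    return default


print(first_match("[Patient__]  two."))
```

Keeping the rules in a list makes their order explicit, which matters here because earlier rules shadow later ones.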

In [3]:
def write_file(txt, fpath):
    with open(fpath, 'w') as file:
        file.writelines(l + '\n' for l in txt)
def write_dataset(data,data_root, split):
    for idx in range(len(data)):
        reformatted = vis2phor(data[idx])
        datum = dic2txt(reformatted)
        fname = f'{idx}.txt'.zfill(10)
        fpath = f'{data_root}{split}{fname}'
        write_file(datum, fpath)

data_dir = '/data/datasets/Exam_v3/'
# set seed to get randomization with reproducible results
dataset_creator = VisionDatasetCreator(sampling_strategy, seed=42)

output_dir = '/data/datasets/phoropter_v2/'
for split in ['train', 'val','test']:
    filePaths = sorted(glob('/data/datasets/Exam_v3/' + split + '/*.txt'))
    for idx, f in enumerate(filePaths):

        captions = dataset_creator.read_captions(f)
        phoropter = phoropter_response(captions)
        fname = f'{idx}.txt'.zfill(10)
        fpath = f'{output_dir}{split}/{fname}'
        write_file(phoropter, fpath)

In [4]:
from glob import glob

filePaths = sorted(glob('/data/datasets/Exam_v3/train/*.txt'))
filePaths[-1]

'/data/datasets/Exam_v3/train/000032.txt'

In [5]:
def has4capitals(text):
    if len(text) < 42:
        return sum(c.isupper() for c in text) > 3
    else:
        return False
for i in captions:
    if has4capitals(i):
        print(i, "length: ",len(i))

[Patient__]  K B N R. length:  21
[Patient__]  HBZD? length:  18
[Patient__]  RKBH. length:  18
[Patient__]  OZRSN. length:  19
[Patient__]  O R K S E. length:  23
[Patient__]  NCKHD. length:  19
[Patient__]  CZSHN length:  18
[Patient__]  O R K S E. length:  23
[Patient__]  Yeah, A P E O R F D Z. length:  35
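
The last line above shows a weakness of counting every capital letter: the `Y` in "Yeah" would be picked up along with the chart letters. One possible refinement, sketched here as an assumption rather than the method used for the dataset, is to drop the role tag first and keep only tokens written entirely in capitals. `chart_letters` is a hypothetical helper.

```python
import re

ROLE_TAG = re.compile(r"^\[\w+\]\s+")      # e.g. "[Patient__]  "
CAPS_TOKEN = re.compile(r"\b[A-Z]+\b")     # whole tokens made only of capitals


def chart_letters(caption):
    """Return the letters read from the chart, lower-cased and concatenated."""
    text = ROLE_TAG.sub("", caption)
    return "".join(CAPS_TOKEN.findall(text)).lower()


print(chart_letters("[Patient__]  Yeah, A P E O R F D Z."))
```

This still treats the pronoun "I" and words such as "OK" as chart letters, so it would need a length or context check before being used unattended.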


In [12]:
# avg percentages of exam phase lengths
gs_len=0.14
sp_len=0.25
ac_len=0.50
cv_len=0.11
# percentage of samples taken from each exam phase
gs_ratio=0.15
sp_ratio=0.30
ac_ratio=0.40
cv_ratio=0.15
# minimum dialogue lengths
gs_min = 2
sp_min = 2
ac_min = 2
cv_min = 2
# maximum dialogue lengths
gs_max = 5
sp_max = 5
ac_max = 5
cv_max = 5

sampling_strategy = dict(
    gs=[gs_len, gs_ratio, gs_min, gs_max],
    sp=[sp_len, sp_ratio, sp_min, sp_max],
    ac=[ac_len, ac_ratio, ac_min, ac_max],
    cv=[cv_len, cv_ratio, cv_min, cv_max]
)

In [7]:
from vision_dataset import VisionDatasetCreator


data_dir = '/data/datasets/phoropter_v2/'
# set seed to get randomization with reproducible results
dataset_creator = VisionDatasetCreator(sampling_strategy=sampling_strategy, seed=42, assistant=False)

# number of samples drawn from each file
samples = 3
# 3 samples from each training file, 32 files total
size = (32*samples)
dataset_creator.load(data_dir, 'train', size)
# 3 samples from each validation file, 5 files total
size = (5*samples)
dataset_creator.load(data_dir, 'val', size)
# 3 samples from each test file, 8 files total
size = (8*samples)
dataset_creator.load(data_dir, 'test', size)

for i in ['train', 'val', 'test']:
    print('\n', i, len(dataset_creator.dataset[i]))

32 files found. Sampling 3 times per file.
sampling file: /data/datasets/phoropter_v2/train/000000.txt
sampling file: /data/datasets/phoropter_v2/train/000001.txt
sampling file: /data/datasets/phoropter_v2/train/000002.txt
sampling file: /data/datasets/phoropter_v2/train/000003.txt
sampling file: /data/datasets/phoropter_v2/train/000004.txt
sampling file: /data/datasets/phoropter_v2/train/000005.txt
sampling file: /data/datasets/phoropter_v2/train/000006.txt
sampling file: /data/datasets/phoropter_v2/train/000007.txt
sampling file: /data/datasets/phoropter_v2/train/000008.txt
sampling file: /data/datasets/phoropter_v2/train/000009.txt
sampling file: /data/datasets/phoropter_v2/train/000010.txt
sampling file: /data/datasets/phoropter_v2/train/000011.txt
sampling file: /data/datasets/phoropter_v2/train/000012.txt
sampling file: /data/datasets/phoropter_v2/train/000013.txt
sampling file: /data/datasets/phoropter_v2/train/000014.txt
sampling file: /data/datasets/phoropter_v2/train/000015.t

In [8]:
from vision_dataset import VisionDataset

train = VisionDataset(dataset_creator.dataset["train"])
val = VisionDataset(dataset_creator.dataset["val"])
test = VisionDataset(dataset_creator.dataset["test"])

In [9]:
train.data[0]

{'dialogue': [{'role': 'Phoropter', 'content': "[{'function':None}]"},
  {'role': 'Assistant', 'content': 'Hello! Is [PATIENT] with us today?'}],
 'response': [{'role': 'Phoropter', 'content': "[{'function':None}]"}]}

# Find 'hzcko' sequences

One technician doesn't like to change the letters on the screen, so they show 'hzcko' to all their patients. I want to randomize those sequences.

First, I'll clean these sequences from Exam_v3

In [23]:
from random import randint
import string
caps = list(string.ascii_uppercase)

# randint is inclusive on both ends, so the upper bound must be len(caps)-1
idx = randint(0, len(caps)-1)
caps[idx]

'U'

In [32]:
from vision_dataset import VisionDatasetCreator
from glob import glob
import string
from random import randint

def write_file(txt, fpath):
    with open(fpath, 'w') as file:
        file.writelines(l + '\n' for l in txt)

filePaths = sorted(glob('/data/datasets/Exam_v3/test/*.txt'))

creator = VisionDatasetCreator(sampling_strategy=sampling_strategy)

caps = list(string.ascii_uppercase)
for f in filePaths:
    captions = creator.read_captions(f)
    fname = f.split('/')[-1]
    for n in range(len(captions)):
        if 'H Z C' in captions[n]:
            print(f, '\n',captions[n])
            for l in 'HZCKO':
                idx = randint(0, len(caps)-1)
                captions[n]= captions[n].replace(l,caps[idx])
            print(captions[n])
    # write_file(captions,f)
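
Replacing the letters one at a time with `str.replace` can rewrite a letter that an earlier iteration just produced (for example `H` becomes `Z`, and the next pass then replaces every `Z`). A single-pass substitution avoids this. The sketch below uses `str.translate`; `randomize_letters` is a hypothetical helper and is not what was run on the dataset.

```python
import random
import string


def randomize_letters(text, letters="HZCKO", rng=random):
    """Replace each of `letters` in `text` with a random capital, in one pass."""
    # Build the whole mapping first, then apply it once so replacements never chain.
    mapping = {ord(c): rng.choice(string.ascii_uppercase) for c in letters}
    return text.translate(mapping)


print(randomize_letters("[Patient__]  H Z C K O.", rng=random.Random(42)))
```

Like the loop above, this changes every matching capital in the line, so a caption that also contains a word such as "Okay" would be altered too.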


## Same for phoropter_v2

In [47]:
filePaths = sorted(glob('/data/datasets/phoropter_v2/train/*.txt'))

creator = VisionDatasetCreator(sampling_strategy=sampling_strategy)

caps_upper = list(string.ascii_uppercase)
caps_lower = list(string.ascii_lowercase)
for f in filePaths:
    captions = creator.read_captions(f)
    fname = f.split('/')[-1]
    for n in range(len(captions)):
        if 'H Z C' in captions[n]:
            print(f, '\n',captions[n],'\n',captions[n+1][:-8])
            letters = captions[n+1][-8:]
            for l in 'HZCKO':
                idx = randint(0, len(caps_upper)-1)
                captions[n]= captions[n].replace(l,caps_upper[idx])
                letters = letters.replace(l.lower(),caps_lower[idx])
            captions[n+1] = captions[n+1][:-8] + letters
            print(captions[n],'\n',captions[n+1])
    write_file(captions,f)

# Limiting Phoropter responses

We have been directed to limit communication with Gordon's code to patient selections and letter identification.

To this end, no commands, hypothetical or otherwise, will be sent to Gordon's code. Only the identified patient responses will be sent.

Responses that contain any function other than `patient_choice` or `accuracy` will be changed to `None`.
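
The rule just stated can be sketched by parsing the payload instead of matching string prefixes. Because the payloads use Python literal syntax (single quotes, `None`), `ast.literal_eval` can read them where a JSON parser could not. This is a minimal sketch of the stated rule only; `limit_response` is a hypothetical name, and payloads that fail to parse are conservatively mapped to the no-op response.

```python
import ast

PREFIX = "[Phoropter]  "
ALLOWED = {"patient_choice", "accuracy"}
STANDARD = "[Phoropter]  [{'function':None}]"


def limit_response(line):
    """Keep a Phoropter line only if every call in it is an allowed function."""
    if not line.startswith(PREFIX):
        return line  # dialogue lines pass through untouched
    try:
        calls = ast.literal_eval(line[len(PREFIX):])
    except (ValueError, SyntaxError):
        return STANDARD  # malformed payload: fall back to "no function"
    if not isinstance(calls, list) or not all(isinstance(c, dict) for c in calls):
        return STANDARD
    names = {c.get("function") for c in calls}
    return line if names <= ALLOWED else STANDARD


print(limit_response("[Phoropter]  [{'function':'pd_adjust','side':'right', 'increment':-1}]"))
```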


In [1]:
from vision_dataset import VisionDatasetCreator, VisionDataset
from glob import glob
# avg percentages of exam phase lengths
gs_len=0.14
sp_len=0.25
ac_len=0.50
cv_len=0.11
# percentage of samples taken from each exam phase
gs_ratio=0.15
sp_ratio=0.30
ac_ratio=0.40
cv_ratio=0.15
# minimum dialogue lengths
gs_min = 2
sp_min = 2
ac_min = 2
cv_min = 2
# maximum dialogue lengths
gs_max = 5
sp_max = 5
ac_max = 5
cv_max = 5

sampling_strategy = dict(
    gs=[gs_len, gs_ratio, gs_min, gs_max],
    sp=[sp_len, sp_ratio, sp_min, sp_max],
    ac=[ac_len, ac_ratio, ac_min, ac_max],
    cv=[cv_len, cv_ratio, cv_min, cv_max]
)

data_dir = '/data/datasets/phoropter_v3/'
# set seed to get randomization with reproducible results
dataset_creator = VisionDatasetCreator(sampling_strategy, seed=42)

trainPaths = glob('/data/datasets/phoropter_v3/train/*.txt')

captions = dataset_creator.read_captions(trainPaths[0])

In [3]:
for dset in ['train', 'val', 'test']:
    paths = sorted(glob(f'/data/datasets/phoropter_v3/{dset}/*.txt'))
    for p in paths:
        captions = dataset_creator.read_captions(p)
        for idx in range(len(captions)):
            if captions[idx].startswith("[Phoropter]"):
                if captions[idx][13] != "[":
                    print(f"{p}\n{captions[idx]}")

In [9]:
captions[17][13:].startswith("[{'function':'pd_adjust'")

True

In [5]:
standard = "[Phoropter]  [{'function':None}]"
choice_same = "[Phoropter]  [{'function':'patient_choice','choice':'same'}]"
def replacement(txt):
    if txt.startswith('[Ph'):
        if txt[13:].startswith("[{'function':'accuracy'"):
            return txt
        elif txt[13:].startswith("[{'function':'patient_choice'"):
            return txt
        elif txt[13:].startswith("[{'function':'display_duochrome"):
            return choice_same
        elif txt[13:].startswith("[{'function':'display_vachart"):
            return choice_same
        else:
            return standard
    else:
        return txt

for idx in range(len(captions)):
    captions[idx] = replacement(captions[idx])
captions[:25]

['[LocalTech]  hello.',
 "[Phoropter]  [{'function':None}]",
 '[Assistant]  Hello! Is [PATIENT] with us today?',
 "[Phoropter]  [{'function':None}]",
 '[LocalTech]  Yes',
 "[Phoropter]  [{'function':None}]",
 '[Assistant]  Great! I\'m [ASSISTANT]. I\'ll guide you though the exam before you see the doctor today. [LOCALTECH] will get you set up. [LOCALTECH], just say "[ASSISTANT], start the exam", and I\'ll know you\'re ready to start. If the PDs need adjusting, I can do that, too.',
 "[Phoropter]  [{'function':None}]",
 '[LocalTech]  Okay. And nose pads, if you ever wanted to know, are usually more uncomfortable when it comes to normal stuff because they sit on your nose quite tightly.  So if you ever wanted to try a plastic lens, an optician will go over that too, okay?',
 "[Phoropter]  [{'function':None}]",
 "[Patient__]  Right, yeah.  That's why I was looking at the plastic.  It's a lot more gentle.  I wear it a lot of times.",
 "[Phoropter]  [{'function':None}]",
 "[LocalTech]  Alri

# Write new dataset

In [6]:
def write_file(txt, fpath):
    with open(fpath, 'w') as file:
        file.writelines(l + '\n' for l in txt)

for dset in ['train', 'val', 'test']:
    paths = sorted(glob(f'/data/datasets/phoropter_v3/{dset}/*.txt'))
    for p in paths:
        captions = dataset_creator.read_captions(p)
        for idx in range(len(captions)):
            captions[idx] = replacement(captions[idx])
        write_file(captions, p)

# Phoropter V4

Now I'll make a dataset for training a multi-label text classifier.

In [3]:
candidate_labels = ["1", "2", "same", "letters", "other"]
ints = range(len(candidate_labels))
mapping = dict(zip(candidate_labels,ints))
mapping = {k:v for k,v in sorted(mapping.items(), key= lambda x:x[0])}
mapping

{'1': 0, '2': 1, 'letters': 3, 'other': 4, 'same': 2}

In [68]:
lab = 2
def onehot(label):
    # labels are 1-based here, so label 2 sets index 1
    num_labels = 5
    new = [0 for i in range(num_labels)]
    new[label-1] = 1
    return new

onehot(lab)

[0, 1, 0, 0, 0]
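
The same idea can be written against label names rather than integers. The sketch below assumes the index order of `candidate_labels` as listed above, which may not match the order that `TextClassificationDatasetCreator` uses internally; `label2id`, `one_hot`, and `decode` are hypothetical names.

```python
candidate_labels = ["1", "2", "same", "letters", "other"]
# Assumption: a label's index is its position in candidate_labels.
label2id = {name: idx for idx, name in enumerate(candidate_labels)}


def one_hot(name):
    """Encode a label name as a one-hot list."""
    vec = [0] * len(candidate_labels)
    vec[label2id[name]] = 1
    return vec


def decode(vec):
    """Map a one-hot list back to its label name."""
    return candidate_labels[vec.index(1)]


print(one_hot("same"), decode(one_hot("same")))
```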

In [1]:
from glob import glob
from text_classification_dataset import TextClassificationDatasetCreator

# create the dataset builder (no seed is passed here)
dataset_creator = TextClassificationDatasetCreator()

dsets= []
for dset in ['train', 'val', 'test']:
    data_dir= f'/data/datasets/phoropter_v4/{dset}'
    data = dataset_creator.dset(dset,data_dir=data_dir)
    dsets.append(data)

for name, dset in zip(['train', 'val', 'test'], dsets):
    print(name)
    for i in range(5):
        print(dset['text'][i], dset['label'][i])

32 files found.
num utterances: 693 labels: 693
5 files found.
num utterances: 113 labels: 113
8 files found.
num utterances: 206 labels: 206
train
Right, yeah.  That's why I was looking at the plastic.  It's a lot more gentle.  I wear it a lot of times. [0, 0, 0, 1, 0]
That's perfect. Thank you. [0, 0, 0, 1, 0]
Okay, O N V E R. [0, 0, 1, 0, 0]
OK.  O H Y Q K. [0, 0, 1, 0, 0]
O N V E R. [0, 0, 1, 0, 0]
val
Thank you. [0, 0, 0, 1, 0]
S is, excuse me, SZR. [0, 0, 1, 0, 0]
Um, S Z R. [0, 0, 1, 0, 0]
O K [0, 0, 1, 0, 0]
yes [0, 0, 0, 1, 0]
test
Hello? [0, 0, 0, 1, 0]
That's me. [0, 0, 0, 1, 0]
And I think that we're ready. [0, 0, 0, 1, 0]
It's blurry, but ZHC. [0, 0, 1, 0, 0]
VHC. [0, 0, 1, 0, 0]


In [2]:
from datasets import Dataset
# !python -m pip install setfit
train = Dataset.from_dict(dsets[0])
val = Dataset.from_dict(dsets[1])
test = Dataset.from_dict(dsets[2])



In [3]:
from setfit import get_templated_dataset

# dummy_dataset = Dataset.from_dict({})
# train = get_templated_dataset(dummy_dataset, candidate_labels, sample_size=8)
# train

In [4]:
thing = train.shuffle()[:5]


In [5]:
val.shuffle()[:5]

{'text': ['yes they are',
  'The first one.',
  'yeah',
  'K D N R O.',
  "they're practically the same"],
 'label': [[0, 0, 0, 1, 0],
  [0, 0, 0, 0, 1],
  [0, 0, 0, 1, 0],
  [0, 0, 1, 0, 0],
  [0, 1, 0, 0, 0]]}

In [11]:
from setfit import SetFitModel
from sentence_transformers import SentenceTransformer

model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
# model_id ='sentence-transformers/all-MiniLM-L6-v2'
model = SetFitModel.from_pretrained(model_id,
                                    multi_target_strategy="one-vs-rest",
                                    use_differentiable_head=True,
                                    head_params={"out_features": dataset_creator.num_classes})
# model = SentenceTransformer(model_id)

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [12]:
from transformers import EarlyStoppingCallback

# Early stopping patience (number of consecutive evaluations without improvement)
early_stopping_patience = 2

# Early stopping threshold (how much the monitored metric must change to count as an improvement)
early_stopping_threshold = -0.001

# Create the callback
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=early_stopping_patience,
                                                early_stopping_threshold=early_stopping_threshold)

# def compute_metrics(p):    
#     pred, labels = p
#     pred = np.argmax(pred, axis=1)
#     accuracy = accuracy_score(y_true=labels, y_pred=pred)
#     recall = recall_score(y_true=labels, y_pred=pred)
#     precision = precision_score(y_true=labels, y_pred=pred)
#     f1 = f1_score(y_true=labels, y_pred=pred)    
#     return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
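
To make the roles of patience and threshold concrete, here is a simplified stand-alone sketch of patience-based early stopping on a loss that should decrease. It is an illustration under stated assumptions, not the `transformers` implementation, and `early_stop_step` is a hypothetical helper.

```python
def early_stop_step(losses, patience=2, threshold=0.0):
    """Return the evaluation index at which training would stop, or None."""
    best = None
    bad_evals = 0
    for step, loss in enumerate(losses):
        # An evaluation counts as an improvement if the loss is lower than the
        # best so far and the change is larger than the threshold.
        if best is None or (loss < best and abs(loss - best) > threshold):
            best = loss
            bad_evals = 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return step
    return None


print(early_stop_step([0.5, 0.4, 0.41, 0.42], patience=2))
```

With a negative threshold such as `-0.001`, the size check in this sketch always passes, so any decrease in the loss resets the patience counter.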

In [13]:
from setfit import Trainer, TrainingArguments
args = TrainingArguments(
    batch_size=64,
    logging_steps=10,
    eval_steps=50,
    save_steps=50,
    num_epochs=4,
    evaluation_strategy="steps",
    save_strategy="steps",
    metric_for_best_model='eval_embedding_loss',
    load_best_model_at_end=True,
)
trainer = Trainer(model=model,
                  train_dataset=train,
                  eval_dataset=val,
                  callbacks=[early_stopping_callback], 
                  args=args)

                                                   

In [14]:
trainer.train()
trainer.eval_dataset = test
metrics = trainer.evaluate()
print(metrics)

***** Running training *****
  Num unique pairs = 316214
  Batch size = 64
  Num epochs = 4
  Total optimization steps = 19764
  0%|          | 0/19764 [00:00<?, ?it/s]

[A                                                

{'embedding_loss': 0.2239, 'learning_rate': 1.0116337885685383e-08, 'epoch': 0.0}



[A                                                 

{'embedding_loss': 0.2444, 'learning_rate': 1.0116337885685383e-07, 'epoch': 0.0}



[A                                                 

{'embedding_loss': 0.2321, 'learning_rate': 2.0232675771370766e-07, 'epoch': 0.0}



[A                                                 

{'embedding_loss': 0.2406, 'learning_rate': 3.034901365705615e-07, 'epoch': 0.01}



[A                                                 

{'embedding_loss': 0.2352, 'learning_rate': 4.046535154274153e-07, 'epoch': 0.01}



[A                                                 

{'embedding_loss': 0.2154, 'learning_rate': 5.058168942842692e-07, 'epoch': 0.01}


                                                 
{'eval_embedding_loss': 0.1965, 'learning_rate': 5.058168942842692e-07, 'epoch': 0.01}
{'embedding_loss': 0.1802, 'learning_rate': 6.06980273141123e-07, 'epoch': 0.01}
{'embedding_loss': 0.2137, 'learning_rate': 7.081436519979768e-07, 'epoch': 0.01}
{'embedding_loss': 0.2082, 'learning_rate': 8.093070308548306e-07, 'epoch': 0.02}
{'embedding_loss': 0.2054, 'learning_rate': 9.104704097116844e-07, 'epoch': 0.02}
{'embedding_loss': 0.2082, 'learning_rate': 1.0116337885685384e-06, 'epoch': 0.02}
{'eval_embedding_loss': 0.1652, 'learning_rate': 1.0116337885685384e-06, 'epoch': 0.02}
{'embedding_loss': 0.1585, 'learning_rate': 1.112797167425392e-06, 'epoch': 0.02}
{'embedding_loss': 0.155, 'learning_rate': 1.213960546282246e-06, 'epoch': 0.02}
{'embedding_loss': 0.1592, 'learning_rate': 1.3151239251391e-06, 'epoch': 0.03}
{'embedding_loss': 0.154, 'learning_rate': 1.4162873039959535e-06, 'epoch': 0.03}
{'embedding_loss': 0.1637, 'learning_rate': 1.5174506828528073e-06, 'epoch': 0.03}
{'eval_embedding_loss': 0.1282, 'learning_rate': 1.5174506828528073e-06, 'epoch': 0.03}
{'embedding_loss': 0.1247, 'learning_rate': 1.6186140617096613e-06, 'epoch': 0.03}
{'embedding_loss': 0.132, 'learning_rate': 1.7197774405665153e-06, 'epoch': 0.03}
{'embedding_loss': 0.1146, 'learning_rate': 1.8209408194233688e-06, 'epoch': 0.04}
{'embedding_loss': 0.1051, 'learning_rate': 1.9221041982802226e-06, 'epoch': 0.04}
{'embedding_loss': 0.115, 'learning_rate': 2.023267577137077e-06, 'epoch': 0.04}
{'eval_embedding_loss': 0.0889, 'learning_rate': 2.023267577137077e-06, 'epoch': 0.04}
{'embedding_loss': 0.1061, 'learning_rate': 2.12443095599393e-06, 'epoch': 0.04}
{'embedding_loss': 0.1087, 'learning_rate': 2.225594334850784e-06, 'epoch': 0.04}
{'embedding_loss': 0.098, 'learning_rate': 2.326757713707638e-06, 'epoch': 0.05}
{'embedding_loss': 0.0983, 'learning_rate': 2.427921092564492e-06, 'epoch': 0.05}
{'embedding_loss': 0.0951, 'learning_rate': 2.5290844714213457e-06, 'epoch': 0.05}
{'eval_embedding_loss': 0.0676, 'learning_rate': 2.5290844714213457e-06, 'epoch': 0.05}
{'embedding_loss': 0.0505, 'learning_rate': 2.6302478502782e-06, 'epoch': 0.05}
{'embedding_loss': 0.0699, 'learning_rate': 2.7314112291350532e-06, 'epoch': 0.05}
  1%|▏         | 271/19764 [01:45<1:22:28,  3.94it/s]
{'embedding_loss': 0.0385, 'learning_rate': 2.832574607991907e-06, 'epoch': 0.06}
{'embedding_loss': 0.0543, 'learning_rate': 2.933737986848761e-06, 'epoch': 0.06}
{'embedding_loss': 0.0766, 'learning_rate': 3.0349013657056146e-06, 'epoch': 0.06}
{'eval_embedding_loss': 0.0544, 'learning_rate': 3.0349013657056146e-06, 'epoch': 0.06}
{'embedding_loss': 0.0538, 'learning_rate': 3.1360647445624688e-06, 'epoch': 0.06}
{'embedding_loss': 0.0491, 'learning_rate': 3.2372281234193226e-06, 'epoch': 0.06}
{'embedding_loss': 0.0333, 'learning_rate': 3.3383915022761763e-06, 'epoch': 0.07}
{'embedding_loss': 0.0401, 'learning_rate': 3.4395548811330305e-06, 'epoch': 0.07}
{'embedding_loss': 0.0603, 'learning_rate': 3.540718259989884e-06, 'epoch': 0.07}
{'eval_embedding_loss': 0.0463, 'learning_rate': 3.540718259989884e-06, 'epoch': 0.07}
{'embedding_loss': 0.0433, 'learning_rate': 3.6418816388467377e-06, 'epoch': 0.07}
{'embedding_loss': 0.0442, 'learning_rate': 3.7430450177035914e-06, 'epoch': 0.07}
{'embedding_loss': 0.0347, 'learning_rate': 3.844208396560445e-06, 'epoch': 0.08}
{'embedding_loss': 0.0384, 'learning_rate': 3.945371775417299e-06, 'epoch': 0.08}
{'embedding_loss': 0.0282, 'learning_rate': 4.046535154274154e-06, 'epoch': 0.08}
{'eval_embedding_loss': 0.0422, 'learning_rate': 4.046535154274154e-06, 'epoch': 0.08}
{'embedding_loss': 0.0255, 'learning_rate': 4.147698533131007e-06, 'epoch': 0.08}
{'embedding_loss': 0.0705, 'learning_rate': 4.24886191198786e-06, 'epoch': 0.09}
{'embedding_loss': 0.0239, 'learning_rate': 4.3500252908447145e-06, 'epoch': 0.09}
{'embedding_loss': 0.0203, 'learning_rate': 4.451188669701568e-06, 'epoch': 0.09}
{'embedding_loss': 0.0283, 'learning_rate': 4.552352048558422e-06, 'epoch': 0.09}
{'eval_embedding_loss': 0.0412, 'learning_rate': 4.552352048558422e-06, 'epoch': 0.09}
{'embedding_loss': 0.0446, 'learning_rate': 4.653515427415276e-06, 'epoch': 0.09}
{'embedding_loss': 0.0114, 'learning_rate': 4.75467880627213e-06, 'epoch': 0.1}
{'embedding_loss': 0.0106, 'learning_rate': 4.855842185128984e-06, 'epoch': 0.1}
{'embedding_loss': 0.0133, 'learning_rate': 4.957005563985838e-06, 'epoch': 0.1}
{'embedding_loss': 0.0091, 'learning_rate': 5.058168942842691e-06, 'epoch': 0.1}
{'eval_embedding_loss': 0.0417, 'learning_rate': 5.058168942842691e-06, 'epoch': 0.1}
{'embedding_loss': 0.0368, 'learning_rate': 5.159332321699545e-06, 'epoch': 0.1}
{'embedding_loss': 0.033, 'learning_rate': 5.2604957005564e-06, 'epoch': 0.11}
{'embedding_loss': 0.0234, 'learning_rate': 5.361659079413253e-06, 'epoch': 0.11}
{'embedding_loss': 0.0242, 'learning_rate': 5.4628224582701065e-06, 'epoch': 0.11}
{'embedding_loss': 0.0223, 'learning_rate': 5.563985837126961e-06, 'epoch': 0.11}
{'eval_embedding_loss': 0.0425, 'learning_rate': 5.563985837126961e-06, 'epoch': 0.11}

Loading best SentenceTransformer model from step 450.

  3%|▎         | 550/19764 [03:46<2:11:40,  2.43it/s]
The `max_length` is `None`. Using the maximum acceptable length according to the current model body: 512.


{'train_runtime': 226.1461, 'train_samples_per_second': 5593.269, 'train_steps_per_second': 87.395, 'epoch': 0.11}


Epoch: 100%|██████████| 4/4 [00:36<00:00,  9.08s/it]
***** Running evaluation *****


{'accuracy': 0.912621359223301}
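
The log above reports "Loading best SentenceTransformer model from step 450". As a rough sanity check, the sketch below recomputes that step from the `eval_embedding_loss` values in the log. It assumes one evaluation every 50 steps, which is inferred from the learning-rate schedule in the log rather than read from the training arguments. The names `EVAL_EVERY` and `best_eval_step` are hypothetical helpers, not part of SetFit.

```python
# eval_embedding_loss values copied from the training log, in logged order.
eval_losses = [0.1965, 0.1652, 0.1282, 0.0889, 0.0676, 0.0544,
               0.0463, 0.0422, 0.0412, 0.0417, 0.0425]

# Assumption: evaluation runs every 50 training steps (inferred from the log).
EVAL_EVERY = 50


def best_eval_step(losses, eval_every):
    """Return (step, loss) of the evaluation with the lowest loss."""
    best_index = min(range(len(losses)), key=losses.__getitem__)
    # The first logged evaluation corresponds to step `eval_every`.
    return (best_index + 1) * eval_every, losses[best_index]


best_step, best_loss = best_eval_step(eval_losses, EVAL_EVERY)
print(best_step, best_loss)
```

Under that assumption the minimum eval loss (0.0412) falls at step 450, which matches the step the trainer reloads. The two later evaluations (0.0417 and 0.0425) are slightly worse, while the training loss keeps dropping.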


In [16]:
model_id = "/home/digitalopt/proj/diarization/checkpoints/step_550"
# model_id ='sentence-transformers/all-MiniLM-L6-v2'
model = SetFitModel.from_pretrained(model_id,
                                    multi_target_strategy="one-vs-rest",
                                    use_differentiable_head=True,
                                    head_params={"out_features": dataset_creator.num_classes})


TypeError: SetFitModel.__init__() got an unexpected keyword argument 'head_params'
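
One plausible cause of this `TypeError`: the checkpoint directory already contains a saved classification head. In that case some SetFit versions appear to skip the code path that consumes `head_params` and forward it unchanged to `SetFitModel.__init__()`, which rejects it. This is an assumption about the installed SetFit version and has not been confirmed in this notebook. The sketch below only builds the keyword arguments, passing `head_params` when no saved head file is found in a local directory. The helper `build_load_kwargs` is hypothetical, and the file name `model_head.pkl` should be checked against the installed version.

```python
import os

# Assumption: SetFit stores a saved head under this file name; verify for your version.
MODEL_HEAD_FILE = "model_head.pkl"


def build_load_kwargs(model_path, num_classes):
    """Build from_pretrained kwargs, adding head options only for a fresh head."""
    kwargs = {"multi_target_strategy": "one-vs-rest"}
    has_saved_head = os.path.isfile(os.path.join(model_path, MODEL_HEAD_FILE))
    if not has_saved_head:
        # No head on disk: ask for a new differentiable head of the right size.
        kwargs["use_differentiable_head"] = True
        kwargs["head_params"] = {"out_features": num_classes}
    return kwargs


# Intended usage in this notebook (not executed here):
# model = SetFitModel.from_pretrained(
#     model_id, **build_load_kwargs(model_id, dataset_creator.num_classes))
```

The check only covers local directories. For a Hub model ID it always includes `head_params`, so the same error could reappear if the remote repository ships a head.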