In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import time
import random
import json
import gc

import PIL
from PIL import Image

import numpy as np
import pandas as pd
import torch
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import nibabel as nib
from einops import rearrange
from scipy import ndimage
from sklearn.neighbors import NearestNeighbors

dir2 = os.path.abspath('../..')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path: 
    sys.path.append(dir1)
    
from research.data.natural_scenes import NaturalScenesDataset
from research.experiments.nsd.nsd_access import NSDAccess
from research.metrics.metrics import cosine_distance, top_knn_test

In [2]:
nsd_path = Path('D:\\Datasets\\NSD\\')
nsd = NaturalScenesDataset(nsd_path, coco_path='X:\\Datasets\\COCO')
stimuli_path = nsd_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'
stimulus_images = h5py.File(stimuli_path, 'r')['imgBrick']

In [5]:
def get_similar_images(images, captions, embeddings, image_id, thresholds, shuffle=True):
    N = embeddings.shape[0]
    distances, nearest_image_ids = neighbors.kneighbors(embeddings[image_id, None], n_neighbors=N)
    distances, nearest_image_ids = distances[0], nearest_image_ids[0]

    t_id = 0
    k_values = [0]
    for k, d in enumerate(distances):
        if t_id >= len(thresholds):
            break
        if d > thresholds[t_id]:
            k_values.append(k)
            t_id += 1

    similar_images = []
    for k in k_values:
        near_image_id = nearest_image_ids[k]
        similar_images.append({
            'image': stimulus_images[near_image_id], 
            'caption': str(captions[near_image_id]),
            'cosine_distance': float(distances[k]),
            'stim_id': int(near_image_id),
            'k': int(k)
        })

    if shuffle:
        shuffle_ids = np.arange(len(similar_images))
        np.random.shuffle(shuffle_ids)
        similar_images = [similar_images[i] for i in shuffle_ids]

    return similar_images

In [6]:
model_name = 'ViT-B=32'
stimulus_key = 'embedding'

save_key = stimulus_key
save_model_name = model_name

stimulus_file = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}.hdf5', 'r')
x = stimulus_file[stimulus_key][:]

stimulus_file_text = h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}-text.hdf5', 'r')
x_text = stimulus_file_text[stimulus_key][:]
x_text = x_text / np.linalg.norm(x_text, axis=-1, keepdims=True)

ids = np.stack([np.arange(73000) for _ in range(5)], axis=-1)
print(ids.shape)

#random_ids = np.arange(73000)
#np.random.shuffle(random_ids)
#print(random_ids)
#x_text = x_text[random_ids]

text_dists = np.einsum('ni,nti->nt', x, x_text)
print(text_dists)

neighbors = NearestNeighbors(metric='cosine')
neighbors.fit(x)

all_captions = np.array([nsd.load_coco(i)[:5] for i in tqdm(range(73000))])
best_captions = all_captions[np.arange(73000), np.argmax(text_dists, axis=1)]

#top_knn_test(x, x_text.reshape(-1, 512), ids.flatten(), k=[1, 5, 10], metric='cosine')

(73000, 5)
[[0.3166824  0.29903737 0.2705831  0.26404428 0.3248881 ]
 [0.3135147  0.29028302 0.30885544 0.29989666 0.2975764 ]
 [0.35390684 0.2882279  0.3490643  0.35402402 0.32551354]
 ...
 [0.28977567 0.2876259  0.29651064 0.30269492 0.31399542]
 [0.31555313 0.31223205 0.30614698 0.3007291  0.26928315]
 [0.29240453 0.28594497 0.32062352 0.29156315 0.2844131 ]]


  0%|          | 0/73000 [00:00<?, ?it/s]

In [7]:
# Results viewer

model_name = 'clip-vit-large-patch14-text'

fold_subset = 'special100'
subset_file = nsd_path / f'nsddata/stimuli/nsd/{fold_subset}.tsv'
subset_stimulus_ids = (np.array(pd.read_csv(subset_file, header=None)[0]) - 1).tolist()

def get_stim_id(p: Path):
    return int(p.name.split('_')[1].split('-')[1])

groups = ['group-5', 'group-6']
run_names = [f'run-002', 'run-003', 'run-004', 'run-005']
subjects = [f'subj0{i}' for i in range(1, 9)]
modes = ['Ground Truth', 'Image Experiment', 'Caption Experiment']

rois = ['', 'Primary_Visual', 'Early_Visual', 'Dorsal_Stream_Visual', 'Ventral_Stream_Visual',
        'MT+_Complex_and_Neighboring_Visual_Areas', 'Medial_Temporal', 'Lateral_Temporal',
        'Temporo-Parieto-Occipital_Junction', 'Superior_Parietal', 'Inferior_Parietal',
        'Posterior_Cingulate', 'Frontal']

@interact(ground_truth_reconstructions=False, 
          run_name=run_names, 
          subject=(1, 8), 
          stim_id=subset_stimulus_ids, 
          mode=modes,
          group_name=groups,
          options=[2, 5],
          roi=rois)
def select(ground_truth_reconstructions, run_name, subject, stim_id, mode, group_name, options, roi):
    subject = subjects[subject-1]
    if ground_truth_reconstructions:
        reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/ground_truth/run-001/images'
    else:
        reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/{roi}/images'
    reconstruction_files = [p for p in reconstructions_path.iterdir() if p.name != 'desktop.ini']
    reconstruction_files.sort(key=get_stim_id)
    stimulus_ids = np.array([get_stim_id(p) for p in reconstruction_files])
    unique_stimulus_ids = np.unique(stimulus_ids)

    neighbors = NearestNeighbors(metric='cosine')
    neighbors.fit(x)

    recon_ids = np.where(stimulus_ids == stim_id)[0]
    if recon_ids.shape[0] < 1:
        print(f'Stim id {stim_id} not presented to {subject}.')
        return
    reconstruction_img = [
        np.array(Image.open(reconstruction_files[recon_id]))
        for recon_id in recon_ids
    ]
    if len(reconstruction_img) == 3:
        reconstruction_img = np.concatenate(reconstruction_img, axis=1)
    elif len(reconstruction_img) == 9:
        reconstruction_img = np.concatenate([
            np.concatenate(reconstruction_img[:3], axis=1),
            np.concatenate(reconstruction_img[3:6], axis=1),
            np.concatenate(reconstruction_img[6:], axis=1),
        ], axis=0)
    
    print("Computer generated images:")
    plt.figure(figsize=(9, 9))
    plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    plt.imshow(reconstruction_img)
    plt.show()
    
    ground_truth_img = stimulus_images[stim_id]
    captions = nsd.load_coco(stim_id)
    
    if mode == 'Ground Truth':
        print("Original stimulus")
        plt.figure(figsize=(3, 3))
        plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        plt.imshow(ground_truth_img)
        plt.show()
        
        print("Stimulus captions")
        for i, c in enumerate(captions):
            print(f'caption_similarity={text_dists[stim_id, i]:.2f}', c)
        return
    
    else:
        if options == 2:
            thresholds = [0.3]
        elif options == 5:
            thresholds = [0.3, 0.4, 0.5, 0.6]
        similar_images = get_similar_images(
            images=stimulus_images, 
            captions=best_captions, 
            embeddings=x, 
            image_id=stim_id, 
            thresholds=thresholds,
            shuffle=True
        )
        
        if mode == 'Image Experiment':
            print("Choose the image that is most similar to the computer generated images.")
            fsize = (len(thresholds) + 1) * 3
            plt.figure(figsize=(fsize, fsize))
            plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
            plt.imshow(np.concatenate([similar_image['image'] for similar_image in similar_images], axis=1))
            plt.show()
            
            @interact(show_ans=False)
            def show(show_ans):
                if show_ans:
                    for i, similar_image in enumerate(similar_images):
                        tag = '(ANSWER)' if similar_image['k'] == 0 else f'd={similar_image["cosine_distance"]:.2f}'
                        print(f'Image {i + 1}, caption: "{similar_image["caption"]}" {tag}')
                            
        if mode == 'Caption Experiment':
            print("Choose the sentence that best describes the computer generated images.")
            @interact(show_ans=False)
            def show(show_ans):
                for i, similar_image in enumerate(similar_images):
                    if show_ans:
                        tag = '(ANSWER)' if similar_image['k'] == 0 else f'd={similar_image["cosine_distance"]:.2f}'
                    else:
                        tag = ''
                    print(f'{i+1}) {similar_image["caption"]} {tag}')

            
                    

interactive(children=(Checkbox(value=False, description='ground_truth_reconstructions'), Dropdown(description=…

In [33]:
# Generate CSV data for MTurk experiments

from googleapiclient.discovery import build

model_name = 'clip-vit-large-patch14-text'
group_name = 'group-5'
run_names = ['run-003']
subjects = [f'subj0{i}' for i in range(1, 9)]

results = {}

for run_name in run_names:
    for subject in subjects:
        API_KEY = 'AIzaSyCsVkNto4IU5yX2pOSeguqknNgbOlDvoVU'
        service = build('drive', 'v3', developerKey=API_KEY)

        folder_id = '18-b4AWlYcR5njrkVAVakqn8AiYx_a8Dj'
        
        query_shared = {
            "includeItemsFromAllDrives": True,
            "supportsAllDrives": True,
            "fields": "*"
        }
        
        for folder_name in (model_name, group_name, run_name, subject, 'merged'):
            result = service.files().list(
                q=f"'{folder_id}' in parents",
                pageToken=None,
                **query_shared
            ).execute()
            files = result.get('files')
            
            for file in files:
                if file['name'] == folder_name:
                    folder_id = file['id']
                    break
        images = []
        next_page_token = None
        while True:
            result = service.files().list(
                q=f"'{folder_id}' in parents", 
                pageToken=next_page_token,
                **query_shared
            ).execute()
            images += result.get("files")
            next_page_token = result.get("nextPageToken")
            if next_page_token is None:
                break
        key = '/'.join((model_name, group_name, run_name, subject))
        results[key] = {
            int(img['name'].split('_')[1].split('-')[1].split('.')[0]): img['id']
            for img in images
        }
        print(key, len(images))
        

clip-vit-large-patch14-text/group-5/run-003/subj01 80
clip-vit-large-patch14-text/group-5/run-003/subj02 80
clip-vit-large-patch14-text/group-5/run-003/subj03 79
clip-vit-large-patch14-text/group-5/run-003/subj04 78
clip-vit-large-patch14-text/group-5/run-003/subj05 81
clip-vit-large-patch14-text/group-5/run-003/subj06 80
clip-vit-large-patch14-text/group-5/run-003/subj07 80
clip-vit-large-patch14-text/group-5/run-003/subj08 78


In [135]:
# Save all of the google ids for ground truth images.

folder_id = '1tzk0iMAYLjo6LeOu5iR-Oz-oPKqs48n7'
images = []
next_page_token = None
while True:
    result = service.files().list(
        q=f"'{folder_id}' in parents", 
        pageToken=next_page_token,
        **query_shared
    ).execute()
    images += result.get("files")
    next_page_token = result.get("nextPageToken")
    if next_page_token is None:
        break
    if len(images) % 1000 == 0:
        print(len(images))
gt_google_ids = {
    int(img['name'].split('.')[0]): img['id']
    for img in images if img['name'].endswith('png')
}

1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000


In [138]:
import json

with open(nsd_path / 'derivatives/stim_google_ids.json', 'w') as f:
    f.write(json.dumps(gt_google_ids))

In [34]:
with open(nsd_path / 'derivatives/stim_google_ids.json', 'r') as f:
    gt_google_ids = json.loads(f.read())

In [41]:
# Generate CSV files for tasks.
import json

task_version = 'version-2'

stimuli_path = nsd_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'
stimulus_images = h5py.File(stimuli_path, 'r')['imgBrick']

for run_name in run_names:
    for subject in subjects:
        print(subject, run_name)
        key = f'{model_name}/{group_name}/{run_name}/{subject}'
        drive_ids = results[key]
        
        img_url = 'https://drive.google.com/uc?id={}&export=view'
        
        task2_txt = 'key_image_url,comparison_image_urls\n'
        #task3_txt = 'image_url,A,B,C,D,E\n'
        task3_txt = 'image_url,A,B\n'
        
        all_images = []
        for stim_id, drive_id in drive_ids.items():
            similar_images = get_similar_images(
                images=stimulus_images, 
                captions=best_captions, 
                embeddings=x, 
                image_id=stim_id, 
                thresholds=[0.3,],# 0.4, 0.5, 0.6],
                shuffle=True
            )
            for img in similar_images:
                del img['image']
            all_images.append(similar_images)
            generated_img_url = img_url.format(drive_id)
            captions = [img["caption"].replace("\"", "\'") for img in similar_images]
            captions = [f'"{caption}"' for caption in captions]
            comparison_img_urls = [
                img_url.format(gt_google_ids[str(img['stim_id'])])
                for img in similar_images
            ]
            
            task2_txt += f'{generated_img_url},"{str(comparison_img_urls)}"\n'
            task3_txt += f"{generated_img_url},{','.join(captions)}\n"
        
        with open(nsd_path / f'derivatives/reconstructions/{key}/task2_{task_version}.csv', 'w') as f:
            f.write(task2_txt)
        with open(nsd_path / f'derivatives/reconstructions/{key}/task3_{task_version}.csv', 'w') as f:
            f.write(task3_txt)
        with open(nsd_path / f'derivatives/reconstructions/{key}/task_{task_version}_info.json', 'w') as f:
            f.write(json.dumps(all_images))

subj01 run-003
subj02 run-003
subj03 run-003
subj04 run-003
subj05 run-003
subj06 run-003
subj07 run-003
subj08 run-003


In [157]:
# Create merged images for MTurk
model_name = 'clip-vit-large-patch14-text'
group_name = 'group-5'
run_names = ['run-002', 'run-003']
subjects = [f'subj0{i}' for i in range(1, 9)]

for run_name in run_names:
    for subject in subjects:
        reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/images'
        reconstruction_files = [p for p in reconstructions_path.iterdir() if p.name != 'desktop.ini']
        reconstruction_files.sort(key=get_stim_id)
        stimulus_ids = np.array([get_stim_id(p) for p in reconstruction_files])
        unique_stimulus_ids = np.unique(stimulus_ids)
        
        for stim_id in unique_stimulus_ids:
            recon_ids = np.where(stimulus_ids == stim_id)[0]
            reconstruction_img = np.concatenate([
                np.array(Image.open(reconstruction_files[recon_id]))
                for recon_id in recon_ids
            ], axis=1)

            merged_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/merged'
            merged_path.mkdir(exist_ok=True, parents=True)
            Image.fromarray(reconstruction_img).save(merged_path / f'merged_stim-{stim_id}.png')

In [7]:
stimulus_images

<HDF5 dataset "imgBrick": shape (73000, 425, 425, 3), type "|u1">

In [12]:
# Create comparison images for fun
model_name = 'clip-vit-large-patch14-text'
group_name = 'group-5'
run_names = ['run-002', 'run-003']
subjects = [f'subj0{i}' for i in range(1, 9)]

for run_name in run_names:
    for subject in subjects:
        print(run_name, subject)
        reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/images'
        reconstruction_files = [p for p in reconstructions_path.iterdir() if p.name != 'desktop.ini']
        reconstruction_files.sort(key=get_stim_id)
        stimulus_ids = np.array([get_stim_id(p) for p in reconstruction_files])
        unique_stimulus_ids = np.unique(stimulus_ids)
        
        for stim_id in unique_stimulus_ids:
            recon_ids = np.where(stimulus_ids == stim_id)[0]
            reconstruction_img = np.concatenate([np.array(Image.fromarray(stimulus_images[stim_id]).resize((512, 512)))] + [
                np.array(Image.open(reconstruction_files[recon_id]))
                for recon_id in recon_ids
            ], axis=1)

            merged_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/comparisons'
            merged_path.mkdir(exist_ok=True, parents=True)
            Image.fromarray(reconstruction_img).save(merged_path / f'merged_stim-{stim_id}.png')

run-002 subj01
run-002 subj02
run-002 subj03
run-002 subj04
run-002 subj05
run-002 subj06
run-002 subj07
run-002 subj08
run-003 subj01
run-003 subj02
run-003 subj03
run-003 subj04
run-003 subj05
run-003 subj06
run-003 subj07
run-003 subj08


In [65]:
# Initialize MTurk client

import boto3

region_name = 'us-east-1'

with open('aws_key.txt') as f:
    aws_access_key_id, aws_secret_access_key = f.read().split('\n')

#endpoint_url = 'https://mturk-requester-sandbox.us-east-1.amazonaws.com'
#image_hit_type_id = '370Y8Q6HBNLLDYRMFUVI4HD0I2711V'
#image_hit_layout_id = '3ES3DE1QF1UWHWRVU4Y10IHAID1HT8'
#caption_hit_type_id = '3K3YEJM75DID2H88D5MF0XGBQVL5WG'
#caption_hit_layout_id = '3DS6MGEQ9MKRE39D9DRS4I533TXEX8'

# Uncomment to use in production
endpoint_url = 'https://mturk-requester.us-east-1.amazonaws.com'
image_hit_type_id = '3Z1ME8JRSUPQWT8L1TOV35Y20G2OCL'
image_hit_layout_id = '3I13DBCGHBTC6JGD4D2XVBI4G6N6ZZ'
caption_hit_type_id = '3VU3J1KUFOCJQXTYZW6DN6P4XGBTGM'
caption_hit_layout_id = '3QZN5R01ALPXHDB9A3CBQAGE82AUGZ'

client = boto3.client(
    'mturk',
    endpoint_url=endpoint_url,
    region_name=region_name,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)

# This will return $10,000.00 in the MTurk Developer Sandbox
client.get_account_balance()

{'AvailableBalance': '2428.90',
 'ResponseMetadata': {'RequestId': 'd63fe8f5-929a-4889-988b-f1b608c7e448',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'd63fe8f5-929a-4889-988b-f1b608c7e448',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '30',
   'date': 'Mon, 19 Sep 2022 05:35:01 GMT'},
  'RetryAttempts': 0}}

In [66]:
import time
model_name = 'clip-vit-large-patch14-text'
group_name = 'group-5'
run_name = 'run-003'
#subjects = [f'subj0{i}' for i in range(1, 9)]
subjects = ['subj05', 'subj08']
task_version = 'version-2'
batch_name = 'batch-1'

num_assignments = 3
task_lifetime = 3600 * 24 * 7

for subject in subjects:
    print(subject)
    reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/'
    task2_data = pd.read_csv(reconstructions_path / f'task2_{task_version}.csv')
    task3_data = pd.read_csv(reconstructions_path / f'task3_{task_version}.csv')
    with open(reconstructions_path / f'task_{task_version}_info.json') as f:
        task_info = json.loads(f.read())
        
    task2_hit_ids = []
    task3_hit_ids = []
    for i, (tinfo, t2, t3) in enumerate(zip(task_info, task2_data.iterrows(), task3_data.iterrows())):
        if i % 20 == 0:
            print(i)
        annotation = {
            'model_name': model_name,
            'group_name': group_name,
            'run_name': run_name,
            'subject': subject,
            'batch_name': batch_name,
            'task_version': task_version,
            'local_id': i,
        }
        t2_layout_params = [{'Name': k, 'Value': v} for k, v in dict(t2[1]).items()]
        t3_layout_params = [{'Name': k, 'Value': v.replace("'", "")} for k, v in dict(t3[1]).items()]
        
        time.sleep(0.1)
        while True:
            #
            task2_result = client.create_hit_with_hit_type(
                HITTypeId=image_hit_type_id,
                MaxAssignments=num_assignments,
                LifetimeInSeconds=task_lifetime,
                RequesterAnnotation=json.dumps({**annotation, 'task': 2}),
                HITLayoutId=image_hit_layout_id,
                HITLayoutParameters=t2_layout_params,
            )
            break
            #except:
            #    time.sleep(5)
        
        
        while True:
            #try:
            task3_result = client.create_hit_with_hit_type(
                HITTypeId=caption_hit_type_id,
                MaxAssignments=num_assignments,
                LifetimeInSeconds=task_lifetime,
                RequesterAnnotation=json.dumps({**annotation, 'task': 3}),
                HITLayoutId=caption_hit_layout_id,
                HITLayoutParameters=t3_layout_params,
            )
            break
            #except:
            #    time.sleep(5)
        
        task2_hit_ids.append(task2_result['HIT']['HITId'])
        task3_hit_ids.append(task3_result['HIT']['HITId'])
        
    with open(reconstructions_path / f'task2_{task_version}_{batch_name}_hits.txt', 'w') as f:
        f.write('\n'.join(task2_hit_ids))
    with open(reconstructions_path / f'task3_{task_version}_{batch_name}_hits.txt', 'w') as f: 
        f.write('\n'.join(task3_hit_ids))

subj05
0
20
40
60
80
subj08
0
20
40
60


In [73]:
# Count number of reviewable HITs
hits = []

next_token = None
while True:
    if next_token:
        result = client.list_reviewable_hits(NextToken=next_token, MaxResults=100)
    else:
        result = client.list_reviewable_hits(MaxResults=100)
    hits += result['HITs']
    
    if 'NextToken' in result:
        next_token = result['NextToken']
    else:
        break
    
print(len(hits))

1391


In [None]:
# Check status of HIT

model_name = 'clip-vit-large-patch14-text'
group_name = 'group-5'
run_name = 'run-003'
#subjects = [f'subj0{i}' for i in range(1, 9)]
subjects = ['subj05', 'subj08']
tasks = ['task2', 'task3']
batch_name = 'batch-1'
task_version = 'version-2'

for subject in subjects:
    print(subject)
    reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/'
    for task in tasks:
        print(task)
    
        with open(reconstructions_path / f'{task}_{task_version}_{batch_name}_hits.txt') as f:
            task_hit_ids = f.read().split('\n')
            
        for hit_id in task_hit_ids:
            hit_assignments_result = client.list_assignments_for_hit(HITId=hit_id)
            
            for assignment in hit_assignments_result['Assignments']:
                task_workers.append({
                    'WorkerId': assignment['WorkerId'],
                    'Time': (assignment['SubmitTime'] - assignment['AcceptTime']).total_seconds(),
                })
                
                answer = ET.fromstring(client.get_assignment(AssignmentId=assignment['AssignmentId'])['Assignment']['Answer'])
                task_results.append(answer[0][1].text)
                
        continue
        with open(reconstructions_path / f'{task}_{task_version}_{batch_name}_results.txt', 'w') as f:
            f.write('\n'.join(task_results))
        with open(reconstructions_path / f'{task}_{task_version}_{batch_name}_workers.txt', 'w') as f:
            f.write(json.dumps(task_workers))


In [80]:
# Check status of a MTurk batch

import xml.etree.ElementTree as ET

model_name = 'clip-vit-large-patch14-text'
group_name = 'group-5'
run_name = 'run-003'
#subjects = [f'subj0{i}' for i in range(1, 9)]
subjects = ['subj05', 'subj08']
tasks = ['task2', 'task3']
batch_name = 'batch-1'
task_version = 'version-2'

for subject in subjects:
    print(subject)
    reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/'
    for task in tasks:
        print(task)
    
        with open(reconstructions_path / f'{task}_{task_version}_{batch_name}_hits.txt') as f:
            task_hit_ids = f.read().split('\n')
            
        num_results = []
        
        for hit_id in task_hit_ids:
            hit_assignments_result = client.list_assignments_for_hit(HITId=hit_id)
            print(hit_assignments_result)
            break
            num_results.append(hit_assignments_result['NumResults'])
        print(num_results)
        print(f'completed={np.all(np.array(num_results) == 3)}')

subj05
task2
{'NextToken': 'p2:BRrak86qEuW6L3y2uJHA4RoJrl8rPwr4Rl+tj++U6VliVLAJ6r8M7yth4zGrBg==', 'NumResults': 3, 'Assignments': [{'AssignmentId': '3ITXP059PAXBZRVUECLBEIOVK8USJP', 'WorkerId': 'A1YSYI926BBOHW', 'HITId': '3HUR21WDE7363DEZ4KN0ZTOMQSTXYJ', 'AssignmentStatus': 'Submitted', 'AutoApprovalTime': datetime.datetime(2022, 9, 21, 23, 53, 10, tzinfo=tzlocal()), 'AcceptTime': datetime.datetime(2022, 9, 18, 23, 52, 54, tzinfo=tzlocal()), 'SubmitTime': datetime.datetime(2022, 9, 18, 23, 53, 10, tzinfo=tzlocal()), 'Answer': '<?xml version="1.0" encoding="ASCII"?><QuestionFormAnswers xmlns="http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"><Answer><QuestionIdentifier>selected_image_idx</QuestionIdentifier><FreeText>0</FreeText></Answer></QuestionFormAnswers>'}, {'AssignmentId': '33CUSNVVN1Q4WQK29AIF81FGS4U88O', 'WorkerId': 'A3CJVRJ34U70Y9', 'HITId': '3HUR21WDE7363DEZ4KN0ZTOMQSTXYJ', 'AssignmentStatus': 'Submitted', 'AutoApprovalTime':

In [84]:
# Save results of MTurk batch

import xml.etree.ElementTree as ET

model_name = 'clip-vit-large-patch14-text'
group_name = 'group-5'
run_name = 'run-003'
#subjects = [f'subj0{i}' for i in range(1, 9)]
subjects = ['subj05', 'subj08']
tasks = ['task2', 'task3']
batch_name = 'batch-1'
task_version = 'version-2'

for subject in subjects:
    print(subject)
    reconstructions_path = nsd_path / f'derivatives/reconstructions/{model_name}/{group_name}/{run_name}/{subject}/'
    for task in tasks:
        print(task)
    
        with open(reconstructions_path / f'{task}_{task_version}_{batch_name}_hits.txt') as f:
            task_hit_ids = f.read().split('\n')
            
        task_results = []
        task_workers = []
        for hit_id in task_hit_ids:
            hit_assignments_result = client.list_assignments_for_hit(HITId=hit_id)
            
            assignment_results = []
            assignment_workers = []
            
            for assignment in hit_assignments_result['Assignments']:
                assignment_workers.append({
                    'WorkerId': assignment['WorkerId'],
                    'Time': (assignment['SubmitTime'] - assignment['AcceptTime']).total_seconds(),
                })
                
                answer = ET.fromstring(client.get_assignment(AssignmentId=assignment['AssignmentId'])['Assignment']['Answer'])
                assignment_results.append(answer[0][1].text)
            task_results.append(assignment_results)
            task_workers.append(assignment_workers)

        with open(reconstructions_path / f'{task}_{task_version}_{batch_name}_results.txt', 'w') as f:
            f.write(json.dumps(task_results))
        with open(reconstructions_path / f'{task}_{task_version}_{batch_name}_workers.txt', 'w') as f:
            f.write(json.dumps(task_workers))


subj05
task2
task3
subj08
task2
task3


In [20]:
assignment

{'AssignmentId': '3S4AW7T80PWYSWHJRW16WULH9KM4LR',
 'WorkerId': 'A3CJVRJ34U70Y9',
 'HITId': '3VQTAXTYOGZI91RTPNYT7BRHBQBUB4',
 'AssignmentStatus': 'Submitted',
 'AutoApprovalTime': datetime.datetime(2022, 9, 17, 1, 49, 3, tzinfo=tzlocal()),
 'AcceptTime': datetime.datetime(2022, 9, 14, 1, 48, 45, tzinfo=tzlocal()),
 'SubmitTime': datetime.datetime(2022, 9, 14, 1, 49, 3, tzinfo=tzlocal()),
 'Answer': '<?xml version="1.0" encoding="ASCII"?><QuestionFormAnswers xmlns="http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"><Answer><QuestionIdentifier>selected_image_idx</QuestionIdentifier><FreeText>1</FreeText></Answer></QuestionFormAnswers>'}

In [13]:
assignment

{'AssignmentId': '3XXU1SWE8090XP8EB4PEBDFTYWDA05',
 'WorkerId': 'AORHXBTOCXFUK',
 'HITId': '3D1UCPY6HTNF89G37RIIT0BGQ7883D',
 'AssignmentStatus': 'Submitted',
 'AutoApprovalTime': datetime.datetime(2022, 9, 17, 2, 38, 1, tzinfo=tzlocal()),
 'AcceptTime': datetime.datetime(2022, 9, 14, 2, 37, 36, tzinfo=tzlocal()),
 'SubmitTime': datetime.datetime(2022, 9, 14, 2, 38, 1, tzinfo=tzlocal()),
 'Answer': '<?xml version="1.0" encoding="ASCII"?><QuestionFormAnswers xmlns="http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"><Answer><QuestionIdentifier>category.label</QuestionIdentifier><FreeText>A surfer in a wetsuit rides on a wave.</FreeText></Answer></QuestionFormAnswers>'}

In [120]:
import xml.etree.ElementTree as ET

hit_assignments_result = client.list_assignments_for_hit(HITId='38RHULDVABT5ODU4QX0T5QRQN22WIO')
for assignment in hit_assignments_result['Assignments']:
    print(client.get_assignment(AssignmentId=assignment['AssignmentId'])['Assignment']['Answer'])
    answer = ET.fromstring(client.get_assignment(AssignmentId=assignment['AssignmentId'])['Assignment']['Answer'])
    print(answer[0][1].text)


<?xml version="1.0" encoding="ASCII"?><QuestionFormAnswers xmlns="http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2005-10-01/QuestionFormAnswers.xsd"><Answer><QuestionIdentifier>category.label</QuestionIdentifier><FreeText>A person holding a surfboard walking on the watery beach.</FreeText></Answer></QuestionFormAnswers>
A person holding a surfboard walking on the watery beach.


In [188]:
client.get_hit(HITId='366FYU4PUT32D8Y150RZW0Z382REK6')

{'HIT': {'HITId': '366FYU4PUT32D8Y150RZW0Z382REK6',
  'HITTypeId': '3K3YEJM75DID2H88D5MF0XGBQVL5WG',
  'HITGroupId': '3PBTVBPQ96GCMSALE03YLMKESSNLGB',
  'HITLayoutId': '3TVX9BGUNUTQFYGJYRQ8P8PWY5PH6W',
  'CreationTime': datetime.datetime(2022, 9, 14, 1, 16, 39, tzinfo=tzlocal()),
  'Title': 'Choose a caption that best describes the generated image.',
  'Description': 'Choose a caption that best describes the generated image.',
  'Question': '<?xml version="1.0"?>\n<HTMLQuestion xmlns="http://mechanicalturk.amazonaws.com/AWSMechanicalTurkDataSchemas/2011-11-11/HTMLQuestion.xsd">\n  <HTMLContent><![CDATA[<html><head><title>HIT</title><meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/></head><body><script src="https://assets.crowd.aws/crowd-html-elements.js"></script>\r\n<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet"\r\n      integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3" crossorig

In [63]:
# PURGE HITS

from pprint import pprint
from datetime import datetime

next_token = None
while True:
    if next_token:
        result = client.list_hits(NextToken=next_token, MaxResults=100)
    else:
        result = client.list_hits(MaxResults=100)
    
    if 'NextToken' in result:
        next_token = result['NextToken']
    hits = result['HITs']
    print(len(hits))
    if len(hits) == 0:
        break
    
    for hit in hits:
        hit_id = hit['HITId']
        status = hit['HITStatus']
        print(f'{hit_id=}, {status=}')

        if status == 'Assignable':
            response = client.update_expiration_for_hit(
                HITId=hit_id,
                ExpireAt=datetime(2015, 1, 1)
            )
            
        if status == 'Reviewable':
            for assignment in client.list_assignments_for_hit(HITId=hit_id)['Assignments']:
                try:
                    print(f'Approving assignment {assignment["AssignmentId"]}')
                    client.approve_assignment(AssignmentId=assignment['AssignmentId'])
                except:
                    print('Failed')

        # Delete the HIT
        try:
            client.delete_hit(HITId=hit_id)
        except:
            print('Not deleted')
        else:
            print('Deleted')

0
