In [1]:
import json
from pathlib import Path
import pickle
import numpy as np
import datetime
import os
import random

In [2]:
# Fraction of the labelled tasks assigned to each subset; checked to sum to 1 below.
split = dict(train=0.4, val=0.2, test=0.4)

In [3]:
# Run configuration: RNG seed, hard-coded task count, and the timestamp of this run
# (used later to name the log directory).
seed = 42
num_data = 500
current_datetime = datetime.datetime.now()

In [4]:
# Seed both global generators; the NumPy one drives the shuffle that defines the split.
np.random.seed(seed)
random.seed(seed)

In [5]:
# Fractions are floats, so compare with a tolerance: an exact `== 1` rejects valid
# splits such as 0.7/0.2/0.1, whose float sum is 0.9999999999999999.
assert all(frac >= 0 for frac in split.values()), "split values must be non-negative"
assert np.isclose(sum(split.values()), 1.0), "split values must sum to 1.0"

In [6]:
# Raw Label Studio choice -> label stored in the dataset. Every choice maps to
# itself except "treatment monitoring", which is folded into "exclude".
label_mapping = {
  **{choice: choice for choice in (
    "(high-risk) screening",
    "6-month follow-up / surveillance",
    "additional workup",
    "exclude",
    "extent of disease / pre-operative planning",
    "unknown",
  )},
  "treatment monitoring": "exclude",
}

In [7]:
def extract_info(data, mapping=None):
  """Extract the id, text payload and mapped indication label of one Label Studio task.

  Args:
    data: one parsed Label Studio JSON export. Reads ``data['task']['data']``,
      ``data['task']['id']`` and the ``data['result']`` list; each result's
      ``'from_name'`` names the annotation and, for ``'indication'``,
      ``result['value']['choices']`` holds the selected label(s).
    mapping: lookup from raw choice to dataset label. Defaults to the
      module-level ``label_mapping`` so existing calls are unchanged.

  Returns:
    ``{'id': ..., 'text': ..., 'label': ...}``, or ``None`` (after printing a
    notice) when the task has no ``indication`` result. If several
    ``indication`` results exist, the last one wins.

  Raises:
    AssertionError: an ``indication`` result does not have exactly one choice.
    KeyError: the selected choice is not a key of ``mapping``.
  """
  if mapping is None:
    mapping = label_mapping
  text = data['task']['data']
  task_id = data['task']['id']  # not `id`: avoid shadowing the builtin
  label = None
  found = False
  for result in data['result']:
    if result['from_name'] != 'indication':
      continue
    choices = result['value']['choices']
    # Validate the count *before* indexing so an empty selection fails with this
    # message instead of a bare IndexError.
    assert len(choices) == 1, (
      f'expected exactly one label for task {task_id}, got {len(choices)}')
    label = choices[0]
    found = True
  if not found:
    print(f'no indication label found for task {task_id}')
    return None
  if label not in mapping:
    raise KeyError(f'unknown indication label {label!r} for task {task_id}')
  return {
    'id': task_id,
    'text': text,
    'label': mapping[label],
  }

In [8]:
# Root of the Label Studio exports. Override with STUDY_INDICATION_DATA_DIR to run
# outside the original cluster without editing the notebook; the default is unchanged.
data_dir = Path(os.environ.get(
  'STUDY_INDICATION_DATA_DIR',
  '/gpfs/data/geraslab/ekr6072/projects/study_indication/data',
))
output_path = data_dir / 'dataset.pkl'

In [13]:
# Sort so exports are processed in a filesystem-independent order (rglob order is
# arbitrary), and fail fast if the directory is empty or mis-pointed rather than
# silently producing empty splits downstream.
data_paths = sorted(data_dir.rglob('label_studio/*.json'))
assert data_paths, f'no Label Studio JSON exports found under {data_dir}'

In [15]:
# Parse every export; tasks without an indication label are reported and skipped
# by extract_info (it returns None for them).
dataset = []
for data_path in data_paths:
  # JSON is UTF-8 by spec; don't depend on the platform's default locale encoding.
  with open(data_path, 'r', encoding='utf-8') as f:
    task = json.load(f)
  data = extract_info(task)
  if data is not None:
    dataset.append(data)

no indication label found for task 3858
no indication label found for task 3814
no indication label found for task 3899
no indication label found for task 3871


In [74]:
# Task ids of the loaded records, in dataset order. A comprehension keeps the
# builtin `id` from being rebound at module level for the rest of the notebook.
ids = [record['id'] for record in dataset]
# The split below ranks records by id (argsort), so ids must be unique for the
# assignment to be well-defined.
assert len(set(ids)) == len(ids), 'duplicate task ids in dataset'

In [75]:
# Positions into `dataset`, ordered by ascending task id.
sort_indices = np.asarray(ids).argsort()

In [76]:
# Shuffle the id-ordered positions with the global NumPy RNG (seeded in the config
# section). Starting from id order rather than file-discovery order means, given
# unique ids, the train/val/test assignment depends only on the seed and the set
# of task ids. Keep this call unchanged to reproduce existing splits.
indices = np.random.permutation(sort_indices)

In [77]:
# Split boundaries into `indices`. Size them from the tasks actually loaded
# (records without an indication label were dropped), not the hard-coded
# `num_data`: with a fixed count, any shortfall or surplus silently lands in the
# test split, which takes everything from `val_max_id` onward.
n_tasks = len(indices)
train_max_id = int(split['train'] * n_tasks)
val_max_id = int(split['val'] * n_tasks) + train_max_id

In [78]:
# Cut the shuffled positions at the two boundaries: [0, train) | [train, val) | [val, end).
train_indices, val_indices, test_indices = np.split(indices, [train_max_id, val_max_id])

In [79]:
def get_items(dataset, indices):
  """Return the records of ``dataset`` at the given positions, as a new list.

  Args:
    dataset: a sequence of records (here: dicts produced by ``extract_info``).
    indices: 1-D integer positions (e.g. a NumPy index array), or a boolean
      mask of the same length as ``dataset``.

  Returns:
    A list holding the selected record objects themselves (not copies), in
    the order given by ``indices``.

  Raises:
    IndexError: on out-of-range positions, a non-integer index array, or a
      boolean mask whose length differs from ``dataset``.
  """
  # Plain list indexing instead of np.array(dataset)[indices]: avoids rebuilding an
  # object array per call, and keeps sequence-like records (e.g. equal-length
  # tuples) from being coerced into a 2-D array.
  positions = np.asarray(indices)
  if positions.dtype == bool:
    if positions.shape != (len(dataset),):
      raise IndexError('boolean mask length does not match dataset')
    positions = np.flatnonzero(positions)
  elif positions.size and not np.issubdtype(positions.dtype, np.integer):
    raise IndexError('get_items expects integer positions or a boolean mask')
  return [dataset[i] for i in positions.tolist()]

In [80]:
# Materialise the records for each subset from its position array.
train_ds, val_ds, test_ds = (
  get_items(dataset, subset_positions)
  for subset_positions in (train_indices, val_indices, test_indices)
)

In [87]:
# NOTE(review): this rebinds `dataset` from the flat record list to a
# subset-name -> records dict; re-running the cells above after this one will break.
dataset = dict(train=train_ds, val=val_ds, test=test_ds)

In [88]:
# Persist all three subsets as one pickled dict at `output_path`.
with output_path.open('wb') as f:
  pickle.dump(dataset, f)

In [89]:
# Timestamp of this run, formatted once; the date and time halves are split out
# of the combined "YYYYmmddTHHMMSS" string.
datetime_string = current_datetime.strftime("%Y%m%dT%H%M%S")
date_string, time_string = datetime_string.split("T")

In [90]:
# Per-run log directory: <data_dir>/logs/<YYYYmmdd>/<HHMMSS>
log_dir = data_dir.joinpath('logs', date_string, time_string)

In [91]:
# The timestamp is fixed in the config cell, so re-running this cell alone would
# hit FileExistsError; exist_ok makes it safe to re-run within the same run.
os.makedirs(log_dir, exist_ok=True)

In [92]:
# One writable log file per subset, collecting that subset's task ids.
f = {
  subset: open(os.path.join(log_dir, f'{subset}_task_ids.log'), 'w')
  for subset in dataset
}

In [93]:
# Header line, then one task id per line for each subset. Reading task["id"]
# inline avoids rebinding the builtin `id` at module level.
for subset in dataset:
  handle = f[subset]
  handle.write("task_id\n")
  handle.writelines(f"{task['id']}\n" for task in dataset[subset])

In [94]:
# Release the per-subset task-id log handles.
for handle in map(f.__getitem__, dataset):
  handle.close()

In [95]:
import re
# Accession number: letters then digits, terminated by a newline. The original
# `([A-Z]*[0-9]*)` also matched the empty string, so a blank "ACCESSION_NUMBER: "
# line passed the one-match check and was logged as an empty accession. The
# alternation accepts exactly the same non-empty strings (A+D* or D+).
pattern = re.compile(r'ACCESSION_NUMBER: ([A-Z]+[0-9]*|[0-9]+)\n')

In [99]:
# One writable log file per subset, collecting that subset's accession numbers.
f = {
  subset: open(os.path.join(log_dir, f'{subset}_acns.log'), 'w')
  for subset in dataset
}

In [100]:
# Header line, then the accession number parsed from each record's metadata text.
for subset in dataset:
  f[subset].write("acn\n")
  for data in dataset[subset]:
    meta = data['text']['meta']
    acn = pattern.findall(meta)
    # Require exactly one *non-empty* match (an empty capture would write a blank
    # line) and name the task so a bad record is easy to locate.
    assert len(acn) == 1 and acn[0], (
      f"expected one accession number in metadata of task {data['id']}, found {acn}")
    f[subset].write(f"{acn[0]}\n")

In [101]:
# Release the per-subset accession-number log handles.
for handle in map(f.__getitem__, dataset):
  handle.close()