Notebook for rendering a tikz plot of the human-annotated phases
===

In [2]:
%matplotlib inline

import os
import numpy as np
import pandas as pd
import itertools

import matplotlib.pyplot as plt
import matplotlib.dates as md
import matplotlib
import pylab as pl

import datetime as dt
import time

from collections import Counter

import json
import os
import re
import random
import itertools
import multiprocessing as mp
from IPython.core.display import display, HTML
import datetime as dt

import sqlite3
from nltk import word_tokenize
from html.parser import HTMLParser
from tqdm import tqdm

In [1]:
import sys
sys.path.append("../annotation_data")
from phase import *

In [3]:
# new approach: load annotations directly using a convenience method
# NOTE(review): get_annotated_phase_df is not defined in this notebook; presumably it
# comes from the `from phase import *` star-import above — confirm.
annotated_df = get_annotated_phase_df()
len(annotated_df)

0it [00:00, ?it/s]



21it [00:03,  6.39it/s]


Site 135571 lacks 302 journals with phase annotations; http://127.0.0.1:5000/siteId/135571
It will be excluded unless the whole site is coded for phases!


52it [00:11,  4.70it/s]


Site 186533 lacks 118 journals with phase annotations; http://127.0.0.1:5000/siteId/186533
It will be excluded unless the whole site is coded for phases!


155it [00:49,  3.12it/s]


Site 722429 lacks 2 journals with phase annotations; http://127.0.0.1:5000/siteId/722429
It will be excluded unless the whole site is coded for phases!


162it [00:51,  3.16it/s]


Site 818475 lacks 5 journals with phase annotations; http://127.0.0.1:5000/siteId/818475
It will be excluded unless the whole site is coded for phases!


167it [00:52,  3.15it/s]

1 non-trivial journals did not have annotations on site 826591 and were skipped.


185it [00:59,  3.11it/s]

2 non-trivial journals did not have annotations on site 857627 and were skipped.


189it [01:00,  3.11it/s]


Site 864283 lacks 7 journals with phase annotations; http://127.0.0.1:5000/siteId/864283
It will be excluded unless the whole site is coded for phases!
1 non-trivial journals did not have annotations on site 866641 and were skipped.


214it [01:08,  3.14it/s]

1 non-trivial journals did not have annotations on site 912713 and were skipped.


228it [01:12,  3.12it/s]

1 non-trivial journals did not have annotations on site 1026937 and were skipped.


229it [01:13,  3.11it/s]

1 non-trivial journals did not have annotations on site 1028140 and were skipped.


231it [01:14,  3.10it/s]

2 non-trivial journals did not have annotations on site 1031202 and were skipped.


237it [01:15,  3.14it/s]


9336

In [36]:
# Count journals where any phase score lies strictly between 0 and 0.9.
# The mask is computed on the whole score matrix at once: np.apply_along_axis would call
# a Python lambda once per row and raises on a frame with zero rows.
# NOTE(review): phase_labels has to exist before this cell runs, but the notebook only
# assigns it in a later cell — confirm the `phase` star-import provides it.
unk_score_matrix = annotated_df[[phase_label + "_score" for phase_label in phase_labels]].values
unks = ((unk_score_matrix < 0.9) & (unk_score_matrix > 0)).any(axis=1)
np.sum(unks)

0

In [39]:
# Fraction of journals whose phase scores sum to more than 1.
# Row sums are taken with a single vectorized reduction instead of a per-row Python
# lambda through np.apply_along_axis (slow, and it raises on a frame with zero rows).
# NOTE(review): the name `transitions` is rebound to unrelated objects by later cells.
transition_score_matrix = annotated_df[[phase_label + "_score" for phase_label in phase_labels]].values
transitions = transition_score_matrix.sum(axis=1) > 1
np.sum(transitions) / len(annotated_df)

0.015852613538988862

In [2]:
annotation_web_client_database = "/home/srivbane/shared/caringbridge/data/projects/qual-health-journeys/instance/cbAnnotator.sqlite"


def get_annotation_db():
    """Open a connection to the annotation web client's SQLite database.

    The connection uses sqlite3.Row as its row factory, so result columns can be
    read by name. The caller owns the connection and must close it.
    """
    connection = sqlite3.connect(annotation_web_client_database,
                                 detect_types=sqlite3.PARSE_DECLTYPES)
    connection.row_factory = sqlite3.Row
    return connection


def get_phase_annotations():
    """Fetch the latest journey-phase annotation for every annotated journal.

    Returns:
        A list of dicts with the keys 'site_id', 'journal_oid' and 'data'
        (the pipe-delimited phase string), one per (site_id, journal_oid) pair,
        ordered from the highest annotation id to the lowest.
    """
    # Connect before entering the try block: if the connection cannot be opened there is
    # nothing to close, and the finally clause must not mask that error with a NameError.
    db = get_annotation_db()
    try:
        # MAX(id) makes SQLite take the bare `data` column from the row with the highest
        # id in each group; without an aggregate the chosen row is arbitrary.
        # NOTE(review): this assumes a higher id means a more recent annotation — confirm.
        # The literal is single-quoted because SQLite reads double quotes as identifiers.
        cursor = db.execute(
            """SELECT site_id, journal_oid, data, MAX(id) AS latest_id
                FROM journalAnnotation
                WHERE annotation_type = 'journal_journey_phase'
                GROUP BY site_id, journal_oid
                ORDER BY latest_id DESC"""
        )
        journal_phase_annotations = cursor.fetchall()
        return [{'site_id': a['site_id'],
                 'journal_oid': a['journal_oid'],
                 'data': a['data']}
                for a in journal_phase_annotations]
    finally:
        db.close()

In [3]:
# Pull the phase annotations out of the annotation database and show how many there are.
phase_annotations = get_phase_annotations()
len(phase_annotations)

7660

In [4]:
# Peek at one record: a dict with site_id, journal_oid and the pipe-delimited phase string.
phase_annotations[0]

{'site_id': 190464,
 'journal_oid': '51be29d46ca004e47900f6a7',
 'data': 'treatment|unknown'}

In [8]:
# Flatten each journal's list of phase labels into one pipe-delimited string
# (the same format as the 'data' field of the raw annotation records).
data = annotated_df['phases'].map("|".join)
annotated_df['data'] = data

In [5]:
# get valid sites
# Note: Currently using a loosely-filtered list, ignoring patient vs caregiver differences
working_dir = "/home/srivbane/shared/caringbridge/data/projects/qual-health-journeys/identify_candidate_sites"
valid_classification_sites_filename = os.path.join(working_dir, "valid_sites_with_75_pct_patient_journals.txt")
with open(valid_classification_sites_filename, 'r') as infile:
    # one site id per line; blank lines are ignored
    stripped_lines = [raw_line.strip() for raw_line in infile]
    valid_sites = [int(entry) for entry in stripped_lines if entry != ""]
len(valid_sites)

16564

In [6]:
# trim out annotations on sites not in the candidate set

# valid_sites is a list; build a set once so each membership test is O(1) instead of a
# linear scan over every candidate site for every annotation.
valid_site_lookup = set(valid_sites)

site_ids_not_in_candidate_sites = []
phase_annotations_filtered = []
for a in phase_annotations:
    site_id = a['site_id']
    if site_id in valid_site_lookup:
        phase_annotations_filtered.append(a)
    else:
        # remember the rejected site ids so they can be tallied afterwards
        site_ids_not_in_candidate_sites.append(site_id)
len(phase_annotations), len(phase_annotations_filtered)

(7660, 7176)

In [7]:
# Number of dropped annotations per excluded site, most frequent first.
Counter(site_ids_not_in_candidate_sites).most_common()

[(882253, 44),
 (966641, 39),
 (857627, 35),
 (867427, 35),
 (1023926, 29),
 (1015767, 27),
 (1028140, 25),
 (846593, 23),
 (1022397, 22),
 (912454, 21),
 (864283, 20),
 (866641, 19),
 (845819, 19),
 (1024634, 18),
 (1010066, 17),
 (987982, 16),
 (1006567, 16),
 (818475, 16),
 (1026937, 15),
 (839971, 11),
 (1065713, 3),
 (910338, 3),
 (816060, 2),
 (540483, 1),
 (592345, 1),
 (104723, 1),
 (810758, 1),
 (889656, 1),
 (890520, 1),
 (896073, 1),
 (103172, 1),
 (1113912, 1)]

In [8]:
# These functions get journal metadata info
def get_db():
    """Open a connection to the journal metadata database (journal.db).

    Rows are returned as sqlite3.Row so columns can be accessed by name; the
    caller is responsible for closing the connection.
    """
    journal_wd = "/home/srivbane/shared/caringbridge/data/projects/qual-health-journeys/extract_site_features"
    connection = sqlite3.connect(os.path.join(journal_wd, "journal.db"),
                                 detect_types=sqlite3.PARSE_DECLTYPES)
    connection.row_factory = sqlite3.Row
    return connection

def get_journal_info(site_id):
    """Return length metadata for every journal of a site, ordered by createdAt.

    Args:
        site_id: The site identifier; anything accepted by int().

    Returns:
        A list of dicts with the keys 'site_id', 'journal_oid' and 'length', where
        length is len(title) + len(body) and a missing title/body counts as 0.
    """
    site_id = int(site_id)

    def get_length(title, body):
        title_len = 0 if title is None else len(title)
        body_len = 0 if body is None else len(body)
        return title_len + body_len

    # Connect before entering the try block: if get_db() fails there is nothing to
    # close, and the finally clause must not replace that error with a NameError.
    db = get_db()
    try:
        cursor = db.execute("""SELECT id, journal_oid, title, body 
                                FROM journal 
                                WHERE site_id = ?
                                ORDER BY createdAt""",
                            (site_id,))
        # fetchall() always returns a list (possibly empty), so no None check is needed
        journals = cursor.fetchall()
        return [{'site_id': site_id,
                 'journal_oid': j['journal_oid'],
                 'length': get_length(j['title'], j['body'])}
                for j in journals]
    finally:
        db.close()

In [10]:
# newer method to compute the phase_lists from the dataframe:
# one array of pipe-delimited phase strings per site
phase_lists = [site_journals['data'].values
               for _, site_journals in tqdm(annotated_df.groupby('site_id'))]

100%|██████████| 203/203 [00:00<00:00, 3442.51it/s]


In [9]:
# Older approach: build phase_lists from the raw annotation records, keeping only sites
# where (almost) every non-trivial journal has a phase annotation.
# NOTE(review): this cell reassigns phase_lists, which the preceding cell also builds
# from annotated_df; whichever of the two cells runs last decides what later cells see.
site_id_count = 0
total_journal_count = 0

phase_lists = []

def sort_by_site_id(tup):
    return tup['site_id']
# itertools.groupby only groups adjacent items, so the records must be sorted by site first
phase_annotations_filtered = sorted(phase_annotations_filtered, key=sort_by_site_id)
for k, g in tqdm(itertools.groupby(phase_annotations_filtered, sort_by_site_id)):
    site_id = k
    site_annotations = list(g)
    if len(site_annotations) < 5:
        continue  # only want complete sites, and sites with less than five definitely aren't complete
    
    journals = get_journal_info(site_id)
    # compared against ALL journals of the site, before the length filter below
    assert len(site_annotations) <= len(journals)
        
    # journals with fewer than 50 characters (title + body) are treated as trivial
    journals = [j for j in journals if j['length'] >= 50]
    unannotated_journals = len(journals) - len(site_annotations)
    # note: a single unannotated non-trivial journal is tolerated (threshold is > 1)
    if unannotated_journals > 1:
        print("Site %d lacks %d journals with phase annotations" % (site_id, unannotated_journals))
        continue  # not every non-trivial journal on this site is coded
    
    # collect the phase strings in journal order, skipping journals without an annotation
    annotation_dict = {a['journal_oid']: a['data'] for a in site_annotations}
    phases = []
    for j in journals:
        if j['journal_oid'] not in annotation_dict:
            continue
        phases.append(annotation_dict[j['journal_oid']])
    #if len(phases) < len(site_annotations) - 1:
    #    print("After matching journals to phase annotations, there were many unmatched annotations!")
    #    print(site_id)
    #    print(len(phases), len(site_annotations))
    #    assert False
    
    #print(phases)
    
    total_journal_count += len(phases)
    site_id_count += 1
    
    phase_lists.append(phases)
    

site_id_count, total_journal_count

19it [00:02,  8.11it/s]

Site 135571 lacks 302 journals with phase annotations


50it [00:09,  5.34it/s]

Site 186533 lacks 118 journals with phase annotations


140it [00:50,  2.79it/s]

Site 623581 lacks 1364 journals with phase annotations


150it [01:06,  2.26it/s]

Site 714287 lacks 519 journals with phase annotations


151it [01:07,  2.25it/s]

Site 722429 lacks 2 journals with phase annotations


205it [01:58,  1.73it/s]


(183, 6945)

In [11]:
# Frequency of each raw phase string across all sites, most common first.
Counter(itertools.chain.from_iterable(phase_lists)).most_common()

[('treatment', 7700),
 ('pretreatment', 557),
 ('cured', 448),
 ('treatment|unknown', 170),
 ('unknown', 121),
 ('end_of_life', 117),
 ('pretreatment|treatment|unknown', 50),
 ('treatment|cured', 42),
 ('cured|unknown', 37),
 ('treatment|cured|unknown', 32),
 ('pretreatment|unknown', 32),
 ('pretreatment|treatment', 16),
 ('end_of_life|unknown', 6),
 ('treatment|end_of_life', 4),
 ('treatment|end_of_life|unknown', 3),
 ('pretreatment|cured', 1)]

In [12]:
def get_labels_from_phase_string(phase):
    """Split a pipe-delimited phase string into its list of informative phase labels.

    'unknown' entries and empty entries are dropped. The legacy tags 'screening' and
    'info_seeking' are folded into a single leading 'pretreatment' label; an explicit
    'pretreatment' next to a legacy tag is not duplicated.

    Args:
        phase: Phase string such as 'treatment|unknown'.

    Returns:
        A new list with at most two labels (callers are free to mutate it).

    Raises:
        AssertionError: If more than two labels remain after normalization.
    """
    legacy_pretreatment_tags = ('screening', 'info_seeking')
    labels = [label for label in phase.split('|') if label not in ('', 'unknown')]
    if any(tag in labels for tag in legacy_pretreatment_tags):
        # Drop the legacy tags AND any explicit 'pretreatment' before inserting the merged
        # label, so 'screening|pretreatment' yields one 'pretreatment' rather than two.
        labels = [label for label in labels
                  if label not in legacy_pretreatment_tags and label != 'pretreatment']
        labels.insert(0, 'pretreatment')
    # TODO Should ensure the labels are sorted by phase temporality here...
    assert len(labels) <= 2  # a given journal should never have more than 2 phases tagged
    return labels


def get_initial_label(phases):
    """Return the phase label a site starts in, or None if no journal carries a label.

    Journals without labels are skipped. The first journal that has labels decides:
    a single label is returned directly, while a two-label journal is resolved by
    looking at the journals that follow it (seek_phase_forward).
    """
    for offset, phase in enumerate(phases):
        labels = get_labels_from_phase_string(phase)
        if len(labels) == 0:
            continue  # nothing usable in this journal, try the next one
        if len(labels) == 1:
            return labels[0]
        # ambiguous journal: resolve using the journals after it
        return seek_phase_forward(labels, phases[offset:], 1)
    return None
    
    
def get_final_label(phases):
    """Return the phase label a site ends in, or None if no journal carries a label.

    Journals without labels are skipped from the end. The last journal that has labels
    decides: a single label is returned directly, while a two-label journal is
    resolved by looking at the journals before it (seek_phase_backward).
    """
    if len(phases) == 0:
        return None
    labels = get_labels_from_phase_string(phases[-1])
    if len(labels) == 0:
        # The last journal is uninformative, so look for the FINAL label among the
        # earlier journals. (This used to call get_initial_label, which returned the
        # site's first phase instead of its last.)
        return get_final_label(phases[:-1])
    elif len(labels) == 1:
        return labels[0]
    # ambiguous journal: resolve using the journals before it
    return seek_phase_backward(labels, phases, -2)

    
def get_transitions(phases):
    """Turn a site's sequence of phase strings into a list of phase transitions.

    Journals whose phase string yields no labels are skipped. The first labelled
    journal only seeds the state; every later labelled journal appends exactly one
    (from_label, to_label) tuple, which is a self-loop when nothing changed.
    Ambiguous situations (journals carrying two labels) are resolved heuristically
    and reported with printed warnings.

    Args:
        phases: Sequence of pipe-delimited phase strings, one per journal, in order.

    Returns:
        A list of (from_label, to_label) tuples.
    """
    transitions = []
    prev_labels = None
    for i in range(len(phases)):
        phase = phases[i]
        curr_labels = get_labels_from_phase_string(phase)
        if len(curr_labels) == 0:
            # no usable label in this journal; prev_labels is left untouched
            continue
        elif prev_labels is None:
            # first labelled journal: nothing to transition from yet
            prev_labels = curr_labels
            continue
        elif prev_labels == curr_labels:
            if len(curr_labels) > 1:  # this is the nightmare scenario...
                # What SHOULD happen here is we base the entire chain of transitions here as 
                # being the NEXT transition that happens... unless this appears at the end of the phase string,
                # at which point it's based on the PREVIOUS transition that happens
                print("Warning: Ambiguous transition.")
            # identical label lists: recorded as a self-loop on the first label
            transition = (prev_labels[0], curr_labels[0])
            transitions.append(transition)
        else:
            # There is a difference! A transition must have occurred
            # default guess: from the last previous label to the first current label
            transition = (prev_labels[-1], curr_labels[0])
            if len(prev_labels) == 2 and len(curr_labels) == 1:
                curr_label = curr_labels[0]
                if curr_label in prev_labels:
                    # the other of the two previous labels is taken as the origin;
                    # remove() mutates the list in place, but prev_labels is rebound
                    # at the end of this iteration
                    prev_labels.remove(curr_label)
                    transition = (prev_labels[0], curr_label)
                else:
                    print("Warning: Two labels to 1 different label; assuming latest phase.")
                    prev_label = prev_labels[-1]
                    if prev_label == 'cured':
                        print("  This is the specifically bad case where 'cured' is linked to something other than treatment!")
                    transition = (prev_label, curr_label)
            elif len(prev_labels) == 1 and len(curr_labels) == 2:
                prev_label = prev_labels[0]
                if prev_label in curr_labels:
                    # work on a copy so curr_labels keeps both labels for the next iteration
                    curr_labels_copy = curr_labels[:]
                    curr_labels_copy.remove(prev_label)
                    curr_label = curr_labels_copy[0]
                    transition = (prev_label, curr_label)
                else:
                    print("Warning: 1 label to 2 different labels. Assuming earliest phase.")
                    curr_label = curr_labels[0]
                    transition = (prev_label, curr_label)
            elif len(prev_labels) == 2 and len(curr_labels) == 2:
                # keeps the default guess computed above
                print("Warning: 2 labels to 2 different labels. I hope this never happens!")
                print("  Assuming transition:", transition, prev_labels, curr_labels)
            transitions.append(transition)
        prev_labels = curr_labels
    return transitions
    
def seek_phase_forward(prev_labels, phases, index):
    """Resolve an ambiguous multi-label journal by scanning the journals after it.

    Args:
        prev_labels: Labels of the ambiguous journal (may be mutated in place).
        phases: Sequence of pipe-delimited phase strings.
        index: Position in phases of the next journal to inspect.

    Returns:
        The resolved label, or None when a journal with a different pair of labels
        is encountered first.
    """
    if index >= len(phases):
        print("Warning: Reached end of list while attempting to resolve ambiguous phase transition.")
        return prev_labels[0]
    labels_here = get_labels_from_phase_string(phases[index])

    if set(labels_here) == set(prev_labels):
        # same label set again: nothing learned, keep scanning
        return seek_phase_forward(labels_here, phases, index + 1)
    if len(labels_here) == 0:
        # unlabelled journal: skip it
        return seek_phase_forward(prev_labels, phases, index + 1)
    if len(labels_here) >= 2:
        print("Warning: This should never happen. (2-label journal followed by a journal with a different two labels.)")
        return None

    # exactly one label in this journal
    single_label = labels_here[0]
    if single_label in prev_labels:
        # the answer is the OTHER label of the ambiguous pair
        prev_labels.remove(single_label)
    else:
        # unrelated label: assume the ambiguous journal was in the latest of its phases
        print("Warning: Ambiguous transition from 2 phases to 1 phase.")
    return prev_labels[-1]
    
    
def seek_phase_backward(prev_labels, phases, index):
    """Resolve an ambiguous multi-label journal by scanning the journals before it.

    Args:
        prev_labels: Labels of the ambiguous journal (may be mutated in place).
        phases: Sequence of pipe-delimited phase strings.
        index: Negative index into phases of the next journal to inspect.

    Returns:
        The resolved label, or None when a journal with a different pair of labels
        is encountered first.
    """
    if -index > len(phases):
        print("Warning: Reached start of list while attempting to resolve ambiguous phase transition.")
        return prev_labels[-1]
    labels_here = get_labels_from_phase_string(phases[index])

    if set(labels_here) == set(prev_labels):
        # same label set again: nothing learned, keep scanning
        return seek_phase_backward(labels_here, phases, index - 1)
    if len(labels_here) == 0:
        # unlabelled journal: skip it
        return seek_phase_backward(prev_labels, phases, index - 1)
    if len(labels_here) >= 2:
        print("Warning: This should never happen. (2-label journal preceded by a journal with a different two labels.)")
        return None

    # exactly one label in this journal
    single_label = labels_here[0]
    if single_label in prev_labels:
        # the answer is the OTHER label of the ambiguous pair
        prev_labels.remove(single_label)
        return prev_labels[-1]
    print("Warning: Ambiguous transition from 2 phases to 1 phase.")
    # in this situation, assume it was the first of the phases;
    # e.g. (pre) -> (treatment|eol) resolves to (treatment) as the final phase
    return prev_labels[0]

In [13]:
# summarize the annotations
# number of informative phase labels per journal, tallied over all sites
annotation_phase_counts = [len(get_labels_from_phase_string(phase))
                           for phases in phase_lists
                           for phase in phases]
Counter(annotation_phase_counts).most_common()

[(1, 9067), (2, 148), (0, 121)]

In [14]:
# how many journals mention 'unknown' anywhere in their raw phase string, and what
# share of all journals that is
unknown_count = sum(1 for phases in phase_lists for phase in phases if 'unknown' in phase)
total = sum(len(phases) for phases in phase_lists)
unknown_count, unknown_count / total

(451, 0.0483076263924593)

In [15]:
# Collect the first and last resolved phase label of every site.
# Sites for which no label can be resolved (None) are left out of the respective list,
# so the two lists are not guaranteed to have the same length.
initial_phases = []
final_phases = []
for phases in phase_lists:
    initial_label = get_initial_label(phases)
    if initial_label is not None:
        initial_phases.append(initial_label)
    
    final_label = get_final_label(phases)
    if final_label is not None:
        final_phases.append(final_label)
    
len(initial_phases), len(final_phases)



(203, 203)

In [16]:
# How often each phase is the first resolved phase of a site.
initial_phase_counts = Counter(initial_phases)
initial_phase_counts.most_common()

[('pretreatment', 114), ('treatment', 88), ('end_of_life', 1)]

In [17]:
# How often each phase is the last resolved phase of a site.
final_phase_counts = Counter(final_phases)
final_phase_counts.most_common()

[('cured', 87), ('end_of_life', 55), ('treatment', 55), ('pretreatment', 6)]

In [18]:
# the four phases, in the order used for every per-phase table below
phase_labels = ['pretreatment', 'treatment', 'end_of_life', 'cured']
# transitions between two different phases that are considered valid
valid_transitions = [
    ('pretreatment', 'treatment'),
    ('treatment', 'end_of_life'),
    ('treatment', 'cured'),
    ('cured', 'treatment'),
]
# in addition, self-loops are always allowed
valid_transitions += [(label, label) for label in phase_labels]

In [19]:
# LaTeX-ready "I xx.x\%" / "F xx.x\%" strings per phase: the share of sites that start
# (I) or end (F) in that phase
initial = {}
final = {}
for phase_label in phase_labels:
    for kind, fmt, counts, observed, renderings in (
        ('initial', "I %.1f\\%%", initial_phase_counts, initial_phases, initial),
        ('final', "F %.1f\\%%", final_phase_counts, final_phases, final),
    ):
        # phases that never occur count as zero
        ratio = counts.get(phase_label, 0) / len(observed)
        rendering = fmt % (ratio * 100)
        print(kind, phase_label, rendering)
        renderings[phase_label] = rendering

initial pretreatment I 56.2\%
final pretreatment F 3.0\%
initial treatment I 43.3\%
final treatment F 27.1\%
initial end_of_life I 0.5\%
final end_of_life F 27.1\%
initial cured I 0.0\%
final cured F 42.9\%


#### Compute the intermediate transitions

In [20]:
# per-site transition lists, plus one flat list of every transition across all sites
site_transitions = [get_transitions(phases) for phases in phase_lists]
all_transitions = [transition
                   for transitions_of_site in site_transitions
                   for transition in transitions_of_site]

  This is the specifically bad case where 'cured' is linked to something other than treatment!
  Assuming transition: ('cured', 'treatment') ['pretreatment', 'cured'] ['treatment', 'cured']


In [21]:
# print sites that have invalid transitions
for transitions in site_transitions:
    if any([transition not in valid_transitions for transition in transitions]):
        for transition in transitions:
            if transition in valid_transitions:
                print("  Valid", transition)
            else:
                print("Invalid", transition)
        print()

  Valid ('pretreatment', 'pretreatment')
  Valid ('pretreatment', 'pretreatment')
  Valid ('pretreatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Valid ('treatment', 'treatment')
  Val

In [22]:
# Keep only sites with more than three transitions.
# NOTE(review): this rebinds site_transitions, so the unfiltered per-site list is no
# longer available after this cell; the threshold 3 is not explained in the notebook.
site_transitions = [t for t in site_transitions if len(t) > 3]

In [23]:
# Visualize each site with the removal of its repeat transitions
trimmed_transitions = []
for transitions in site_transitions:
    trimmed = [transitions[i] for i in range(len(transitions) - 1) if transitions[i][0] != transitions[i][1] or transitions[i] != transitions[i+1]]
    if len(trimmed) == 0 or transitions[-1] != trimmed[-1]:
        trimmed.append(transitions[-1])
    trimmed_transitions.append(tuple(trimmed))
Counter(trimmed_transitions).most_common()

[((('treatment', 'treatment'), ('treatment', 'end_of_life')), 19),
 ((('treatment', 'treatment'),), 18),
 ((('pretreatment', 'pretreatment'),
   ('pretreatment', 'treatment'),
   ('treatment', 'treatment')),
  18),
 ((('pretreatment', 'pretreatment'),
   ('pretreatment', 'treatment'),
   ('treatment', 'treatment'),
   ('treatment', 'cured'),
   ('cured', 'cured')),
  15),
 ((('pretreatment', 'pretreatment'),
   ('pretreatment', 'treatment'),
   ('treatment', 'treatment'),
   ('treatment', 'cured')),
  15),
 ((('treatment', 'treatment'),
   ('treatment', 'end_of_life'),
   ('end_of_life', 'end_of_life')),
  11),
 ((('treatment', 'treatment'), ('treatment', 'cured')), 10),
 ((('pretreatment', 'treatment'), ('treatment', 'treatment')), 8),
 ((('treatment', 'treatment'), ('treatment', 'cured'), ('cured', 'cured')), 6),
 ((('pretreatment', 'pretreatment'),
   ('pretreatment', 'treatment'),
   ('treatment', 'treatment'),
   ('treatment', 'end_of_life')),
  6),
 ((('pretreatment', 'treatment'

In [24]:
# Tally every (from_label, to_label) transition across all sites, most common first.
all_transition_counts = Counter(all_transitions)
all_transition_counts.most_common()

[(('treatment', 'treatment'), 7657),
 (('pretreatment', 'pretreatment'), 501),
 (('cured', 'cured'), 379),
 (('pretreatment', 'treatment'), 142),
 (('treatment', 'cured'), 137),
 (('end_of_life', 'end_of_life'), 69),
 (('treatment', 'end_of_life'), 57),
 (('cured', 'treatment'), 41),
 (('treatment', 'pretreatment'), 22),
 (('cured', 'pretreatment'), 4),
 (('pretreatment', 'cured'), 3)]

In [25]:
# compute each of the transition probabilities
transition_counts_list = []
for phase_label_1 in phase_labels:
    transition_counts = []
    for phase_label_2 in phase_labels:
        transition = (phase_label_1, phase_label_2)
        if transition in all_transition_counts:
            transition_count = all_transition_counts[transition]
        else:
            transition_count = 0
        transition_counts.append(transition_count)
    transition_counts_list.append(transition_counts)
            
transition_counts_list

[[501, 142, 0, 3], [22, 7657, 57, 137], [0, 0, 69, 0], [4, 41, 0, 379]]

In [26]:
# Edge line widths are scaled linearly with the transition probability, from
# min_thickness (0%) to max_thickness (100%), in millimetres.
min_thickness = 0.1
max_thickness = 0.6
thickness_range = max_thickness - min_thickness

# transitions[src][dst]          -> LaTeX-escaped percentage of src's outgoing transitions
# transition_thickness[src][dst] -> matching line width, e.g. "0.49mm"
transitions = {}
transition_thickness = {}
for i, phase_label_1 in enumerate(phase_labels):
    total_outgoing_transitions = sum(transition_counts_list[i])
    transitions[phase_label_1] = {}
    transition_thickness[phase_label_1] = {}
    for j, phase_label_2 in enumerate(phase_labels):
        # A phase without any observed outgoing transition is rendered as 0% on every
        # edge instead of failing with ZeroDivisionError.
        if total_outgoing_transitions:
            pct = transition_counts_list[i][j] / total_outgoing_transitions
        else:
            pct = 0.0
        rendering = "%.1f\\%%" % (pct * 100)
        transitions[phase_label_1][phase_label_2] = rendering

        thickness = min_thickness + (thickness_range * pct)
        thickness_rendering = "%.2fmm" % thickness
        transition_thickness[phase_label_1][phase_label_2] = thickness_rendering
transitions

{'pretreatment': {'pretreatment': '77.6\\%',
  'treatment': '22.0\\%',
  'end_of_life': '0.0\\%',
  'cured': '0.5\\%'},
 'treatment': {'pretreatment': '0.3\\%',
  'treatment': '97.3\\%',
  'end_of_life': '0.7\\%',
  'cured': '1.7\\%'},
 'end_of_life': {'pretreatment': '0.0\\%',
  'treatment': '0.0\\%',
  'end_of_life': '100.0\\%',
  'cured': '0.0\\%'},
 'cured': {'pretreatment': '0.9\\%',
  'treatment': '9.7\\%',
  'end_of_life': '0.0\\%',
  'cured': '89.4\\%'}}

### Fill the tikz template

The template is TikZ picture code that renders a flowchart of the phases, with the annotation statistics computed above inserted into it.

In [27]:
from jinja2 import Template

# Body of a TikZ picture. Placeholders use << >> rather than Jinja's default braces
# (presumably to avoid clashing with the LaTeX braces). This is a regular, non-raw string,
# so every LaTeX backslash is written as "\\": a bare "\path" would rely on the invalid
# escape sequence "\p", which Python only tolerates with a deprecation warning.
tikz_template_string = """
    % Place nodes
    \\node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{<<initial['pretreatment']>>\\\\ <<final['pretreatment']>>}}] [block] (pre) {Pre-treatment};
    \\node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{<<initial['treatment']>>\\\\ <<final['treatment']>>}}] [block, below of=pre, node distance=2.2cm] (treatment) {Acute treatment};
    \\node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{<<initial['end_of_life']>>\\\\ <<final['end_of_life']>>}}] [block, below of=treatment, node distance=3cm] (eol) {End-of-life};
    \\node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{<<initial['cured']>>\\\\ <<final['cured']>>}}] [block, right of=eol, node distance=4cm] (noed) {No evidence of disease};
    % Draw edges
    \\path[->] (pre) edge [loop left, distance=1cm, line width=<<transition_thickness['pretreatment']['pretreatment']>>] node {<<transitions['pretreatment']['pretreatment']>>} (pre);
    \\path[->] (treatment) edge [loop left, distance=1cm, line width=<<transition_thickness['treatment']['treatment']>>] node {<<transitions['treatment']['treatment']>>} (treatment);
    \\path[->] (eol) edge [loop left, distance=1cm, line width=<<transition_thickness['end_of_life']['end_of_life']>>] node {<<transitions['end_of_life']['end_of_life']>>} (eol);
    \\path[->] (noed) edge [loop left, distance=1cm, line width=<<transition_thickness['cured']['cured']>>] node {<<transitions['cured']['cured']>>} (noed);
    \\path [line, line width=<<transition_thickness['pretreatment']['treatment']>>] (pre) -- node {<<transitions['pretreatment']['treatment']>>} (treatment);
    \\path [line, line width=<<transition_thickness['treatment']['end_of_life']>>] (treatment) -- node {<<transitions['treatment']['end_of_life']>>} (eol);
    \\path [line, line width=<<transition_thickness['treatment']['cured']>>] (treatment) -- node {<<transitions['treatment']['cured']>>} (noed);
    \\path [line, line width=<<transition_thickness['cured']['treatment']>>] (noed) |- node[right] {<<transitions['cured']['treatment']>>} (treatment);
"""
tikz_template = Template(tikz_template_string, variable_start_string="<<", variable_end_string=">>")

In [28]:
# Fill the template with the computed statistics and print the resulting TikZ code.
print(tikz_template.render(initial=initial, 
                           final=final, 
                           transitions=transitions, 
                           transition_thickness=transition_thickness))


    % Place nodes
    \node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{I 56.2\%\\ F 3.0\%}}] [block] (pre) {Pre-treatment};
    \node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{I 43.3\%\\ F 27.1\%}}] [block, below of=pre, node distance=2.2cm] (treatment) {Acute treatment};
    \node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{I 0.5\%\\ F 27.1\%}}] [block, below of=treatment, node distance=3cm] (eol) {End-of-life};
    \node[label={[align=left,anchor=south west,shift={(block.north west)},inner sep=2pt]{I 0.0\%\\ F 42.9\%}}] [block, right of=eol, node distance=4cm] (noed) {No evidence of disease};
    % Draw edges
    \path[->] (pre) edge [loop left, distance=1cm, line width=0.49mm] node {77.6\%} (pre);
    \path[->] (treatment) edge [loop left, distance=1cm, line width=0.59mm] node {97.3\%} (treatment);
    \path[->] (eol) edge [loop left, distance=1cm, line width=0.60mm] node {10