In [2]:
'''
Use the files in examples/ to generate changes that will be required
'''

from pathlib import Path
from openai import OpenAI
from IPython.display import display, Markdown
from starsim import llm
MODEL = "gpt-4o-mini"

client = OpenAI()

In [24]:
diff_file = Path("../zombiesim.diff").resolve()
git_diffs = llm.GitDiffSuggestions(diff_file.as_posix(), include_patterns=["starsim/diseases/*.py"])
git_diffs.summarize()

Number of files found: 8
Number of hunks: 47
Names of files found: ['starsim/diseases/cholera.py', 'starsim/diseases/ebola.py', 'starsim/diseases/gonorrhea.py', 'starsim/diseases/hiv.py', 'starsim/diseases/measles.py', 'starsim/diseases/ncd.py', 'starsim/diseases/sir.py', 'starsim/diseases/syphilis.py']


In [34]:
prompt_system_template = '''
Your task is to look for at git diffs for code relating to how the python starsim module (imported as ss) is used. Summarize what changes might 
affect other code depending on the starsim package. Do not include suggestions for changes unrelated to starsim. 
For the suggestion provide your answers as json object with the following format:

{"previous": The code before the change, "new": The code after the change}
{"previous": The code before the change, "new": The code after the change}
{"previous": The code before the change, "new": The code after the change}

If there is no change needed, provide an empty object.
'''

prompt_user_template = '''
diff hunk:

{}

Identified changes:
'''

In [None]:
file_name = 'starsim/diseases/sir.py'
results = []
for diff in git_diffs.diffs:
    if file_name == diff["file"]:
        print("Number of hunks found: ", len(diff["hunks"]))
        for i, hunk in enumerate(diff["hunks"]):                
                system_content = prompt_system_template
                user_content = prompt_user_template.format(hunk)
                print(user_content)
                response = llm.complete(client, system_content=system_content, user_content=user_content, model=MODEL)
                results.append(response.choices[0].message.content)
                display(Markdown(response.choices[0].message.content))


Number of hunks found:  6

diff hunk:

@@ -3,7 +3,8 @@ Define SIR and SIS disease modules
 """
 
 import numpy as np
-import matplotlib.pyplot as pl
+import sciris as sc
+import matplotlib.pyplot as plt
 import starsim as ss
 
 


Identified changes:



```json
{"previous": "import matplotlib.pyplot as pl", "new": "import matplotlib.pyplot as plt"}
{"previous": "", "new": "import sciris as sc"}
```


diff hunk:

@@ -19,22 +20,27 @@ class SIR(ss.Infection):
     """
     def __init__(self, pars=None, **kwargs):
         super().__init__()
-        self.default_pars(
-            beta = 0.1,
+        self.define_pars(
+            beta = ss.beta(0.1),
             init_prev = ss.bernoulli(p=0.01),
-            dur_inf = ss.lognorm_ex(mean=6),
+            dur_inf = ss.lognorm_ex(mean=ss.dur(6)),
             p_death = ss.bernoulli(p=0.01),
         )
         self.update_pars(pars, **kwargs)
 
-        self.add_states(
-            ss.BoolArr('recovered'),
-            ss.FloatArr('ti_recovered'),
-            ss.FloatArr('ti_dead'),
+        self.define_states(
+            ss.State('susceptible', default=True, label='Susceptible'),
+            ss.State('infected', label='Infectious'),
+            ss.State('recovered', label='Recovered'),
+            ss.FloatArr('ti_infected', label='Time of infection'),
+            ss.FloatArr('ti_recovered', label='Time of recovery'),
+      

```json
{"previous": "self.default_pars(\n            beta = 0.1,\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=6),\n            p_death = ss.bernoulli(p=0.01),\n        )","new": "self.define_pars(\n            beta = ss.beta(0.1),\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=ss.dur(6)),\n            p_death = ss.bernoulli(p=0.01),\n        )"}
{"previous": "self.add_states(\n            ss.BoolArr('recovered'),\n            ss.FloatArr('ti_recovered'),\n            ss.FloatArr('ti_dead'),\n        )","new": "self.define_states(\n            ss.State('susceptible', default=True, label='Susceptible'),\n            ss.State('infected', label='Infectious'),\n            ss.State('recovered', label='Recovered'),\n            ss.FloatArr('ti_infected', label='Time of infection'),\n            ss.FloatArr('ti_recovered', label='Time of recovery'),\n            ss.FloatArr('ti_dead', label='Time of death'),\n            ss.FloatArr('rel_sus', default=1.0, label='Relative susceptibility'),\n            ss.FloatArr('rel_trans', default=1.0, label='Relative transmission'),\n        )"}
{"previous": "def update_pre(self):","new": "def step_state(self):"}
```


diff hunk:

@@ -47,11 +53,10 @@ class SIR(ss.Infection):
             sim.people.request_death(deaths)
         return
 
-    def set_prognoses(self, uids, source_uids=None):
+    def set_prognoses(self, uids, sources=None):
         """ Set prognoses """
-        super().set_prognoses(uids, source_uids)
-        ti = self.sim.ti
-        dt = self.sim.dt
+        super().set_prognoses(uids, sources)
+        ti = self.t.ti
         self.susceptible[uids] = False
         self.infected[uids] = True
         self.ti_infected[uids] = ti


Identified changes:



```json
{"previous": "def set_prognoses(self, uids, source_uids=None):", "new": "def set_prognoses(self, uids, sources=None):"}
{"previous": "super().set_prognoses(uids, source_uids)", "new": "super().set_prognoses(uids, sources)"}
{"previous": "dt = self.sim.dt", "new": "ti = self.t.ti"}
```


diff hunk:

@@ -66,25 +71,32 @@ class SIR(ss.Infection):
         will_die = p.p_death.rvs(uids)
         dead_uids = uids[will_die]
         rec_uids = uids[~will_die]
-        self.ti_dead[dead_uids] = ti + dur_inf[will_die] / dt # Consider rand round, but not CRN safe
-        self.ti_recovered[rec_uids] = ti + dur_inf[~will_die] / dt
+        self.ti_dead[dead_uids] = ti + dur_inf[will_die] # Consider rand round, but not CRN safe
+        self.ti_recovered[rec_uids] = ti + dur_inf[~will_die]
         return
 
-    def update_death(self, uids):
+    def step_die(self, uids):
         """ Reset infected/recovered flags for dead agents """
         self.susceptible[uids] = False
         self.infected[uids] = False
         self.recovered[uids] = False
         return
 
-    def plot(self):
+    def plot(self, **kwargs):
         """ Default plot for SIR model """
-        fig = pl.figure()
-        for rkey in ['susceptible', 'infected', 'recovered']:
-            pl.plot(self.resul

```json
{"previous": "self.ti_dead[dead_uids] = ti + dur_inf[will_die] / dt # Consider rand round, but not CRN safe\nself.ti_recovered[rec_uids] = ti + dur_inf[~will_die] / dt","new": "self.ti_dead[dead_uids] = ti + dur_inf[will_die] # Consider rand round, but not CRN safe\nself.ti_recovered[rec_uids] = ti + dur_inf[~will_die]"}
{"previous": "def update_death(self, uids):","new": "def step_die(self, uids):"}
{"previous": "def plot(self):","new": "def plot(self, **kwargs):"}
{"previous": "fig = pl.figure()","new": "fig = plt.figure()"}
{"previous": "for rkey in ['susceptible', 'infected', 'recovered']:","new": "for rkey in ['n_susceptible', 'n_infected', 'n_recovered']:"}
{"previous": "pl.legend()","new": "plt.legend(frameon=False)"}
{"previous": "return fig","new": "return ss.return_fig(fig)"}
```


diff hunk:

@@ -96,70 +108,79 @@ class SIS(ss.Infection):
     """
     def __init__(self, pars=None, *args, **kwargs):
         super().__init__()
-        self.default_pars(
-            beta = 0.05,
+        self.define_pars(
+            beta = ss.beta(0.05),
             init_prev = ss.bernoulli(p=0.01),
-            dur_inf = ss.lognorm_ex(mean=10),
-            waning = 0.05,
+            dur_inf = ss.lognorm_ex(mean=ss.dur(10)),
+            waning = ss.rate(0.05),
             imm_boost = 1.0,
         )
         self.update_pars(pars=pars, *args, **kwargs)
 
-        self.add_states(
+        self.define_states(
             ss.FloatArr('ti_recovered'),
             ss.FloatArr('immunity', default=0.0),
         )
         return
 
-    def update_pre(self):
+    def step_state(self):
         """ Progress infectious -> recovered """
-        recovered = (self.infected & (self.ti_recovered <= self.sim.ti)).uids
+        recovered = (self.infected & (self.ti_recovered <= self

```json
{"previous": "self.default_pars(\n            beta = 0.05,\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=10),\n            waning = 0.05,\n            imm_boost = 1.0,\n        )", "new": "self.define_pars(\n            beta = ss.beta(0.05),\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=ss.dur(10)),\n            waning = ss.rate(0.05),\n            imm_boost = 1.0,\n        )"}
{"previous": "self.add_states(\n             ss.FloatArr('ti_recovered'),\n             ss.FloatArr('immunity', default=0.0),\n         )", "new": "self.define_states(\n             ss.FloatArr('ti_recovered'),\n             ss.FloatArr('immunity', default=0.0),\n         )"}
{"previous": "self.ti_infected[uids] = self.sim.ti", "new": "self.ti_infected[uids] = self.ti"}
{"previous": "self.ti_recovered[uids] = self.sim.ti + dur_inf / self.sim.dt", "new": "self.ti_recovered[uids] = self.ti + dur_inf"}
{"previous": "self.results += ss.Result(self.name, 'rel_sus', self.sim.npts, dtype=float)", "new": "self.define_results(\n            ss.Result('rel_sus', dtype=float, label='Relative susceptibility')\n        )"}
{"previous": "self.results['rel_sus'][self.sim.ti] = self.rel_sus.mean()", "new": "self.results['rel_sus'][self.ti] = self.rel_sus.mean()"}
{"previous": "fig = pl.figure()", "new": "fig = plt.figure()"}
{"previous": "for rkey in ['susceptible', 'infected']:", "new": "for rkey in ['n_susceptible', 'n_infected']:"}
{"previous": "return fig", "new": "return ss.return_fig(fig)"}
```


diff hunk:

@@ -169,27 +190,27 @@ __all__ += ['sir_vaccine']
 class sir_vaccine(ss.Vx):
     """
     Create a vaccine product that affects the probability of infection.
-    
-    The vaccine can be either "leaky", in which everyone who receives the vaccine 
-    receives the same amount of protection (specified by the efficacy parameter) 
+
+    The vaccine can be either "leaky", in which everyone who receives the vaccine
+    receives the same amount of protection (specified by the efficacy parameter)
     each time they are exposed to an infection. The alternative (leaky=False) is
     that the efficacy is the probability that the vaccine "takes", in which case
     that person is 100% protected (and the remaining people are 0% protected).
-    
+
     Args:
         efficacy (float): efficacy of the vaccine (0<=efficacy<=1)
         leaky (bool): see above
     """
     def __init__(self, pars=None, *args, **kwargs):
         super().__init__()
-        self.default_pars(
+      

```json
{"previous": "self.default_pars(\n            efficacy = 0.9,\n            leaky = True\n        )", "new": "self.define_pars(\n            efficacy = 0.9,\n            leaky = True\n        )"}
```

In [64]:
import re
import json

def process_results(results):
    processed = []
    pattern = r'\n(.+?)\n(?=[^\n]*$)'
    for r in results:
        if r:
            match = re.search(pattern, r, re.DOTALL)
            if match:
                s = match.group(0).splitlines()
                for e in s:
                    if e:
                        processed.append(json.loads(e))
    return processed

processed = process_results(results)
print(len(processed))
processed

25


[{'previous': 'import matplotlib.pyplot as pl',
  'new': 'import matplotlib.pyplot as plt'},
 {'previous': '', 'new': 'import sciris as sc'},
 {'previous': 'self.default_pars(\n            beta = 0.1,\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=6),\n            p_death = ss.bernoulli(p=0.01),\n        )',
  'new': 'self.define_pars(\n            beta = ss.beta(0.1),\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=ss.dur(6)),\n            p_death = ss.bernoulli(p=0.01),\n        )'},
 {'previous': "self.add_states(\n            ss.BoolArr('recovered'),\n            ss.FloatArr('ti_recovered'),\n            ss.FloatArr('ti_dead'),\n        )",
  'new': "self.define_states(\n            ss.State('susceptible', default=True, label='Susceptible'),\n            ss.State('infected', label='Infectious'),\n            ss.State('recovered', label='Recovered'),\n            ss.FloatArr('ti_infected', label='Time

In [77]:
s = ',\n'.join([str(p) for p in processed])
print(s)

{'previous': 'import matplotlib.pyplot as pl', 'new': 'import matplotlib.pyplot as plt'},
{'previous': '', 'new': 'import sciris as sc'},
{'previous': 'self.default_pars(\n            beta = 0.1,\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=6),\n            p_death = ss.bernoulli(p=0.01),\n        )', 'new': 'self.define_pars(\n            beta = ss.beta(0.1),\n            init_prev = ss.bernoulli(p=0.01),\n            dur_inf = ss.lognorm_ex(mean=ss.dur(6)),\n            p_death = ss.bernoulli(p=0.01),\n        )'},
{'previous': "self.add_states(\n            ss.BoolArr('recovered'),\n            ss.FloatArr('ti_recovered'),\n            ss.FloatArr('ti_dead'),\n        )", 'new': "self.define_states(\n            ss.State('susceptible', default=True, label='Susceptible'),\n            ss.State('infected', label='Infectious'),\n            ss.State('recovered', label='Recovered'),\n            ss.FloatArr('ti_infected', label='Time of infect

In [78]:
# write s to a file
with open("zombie_A_changes.jsonl", "w") as f:
    f.write(s)

## Appply changes to zombie.py

In [None]:
file = Path("../zombie.py")
with open(file, "r") as f:
    code = f.read()

prompt = '''
Given this list of changes, update the code below to reflect the changes.

Changes: 
{}

Code:
{}
'''

response = llm.complete(client, user_content=prompt.format(s, code), model=MODEL)

In [None]:
print(response.choices[0].message.content)

Here's the updated code reflecting the changes you provided:

```python
import matplotlib.pyplot as plt
import sciris as sc
import numpy as np

class Zombie(ss.SIR):
    """ Extent the base SIR class to represent Zombies! """
    def __init__(self, pars=None, **kwargs):
        super().__init__()

        self.define_pars(
            inherit=True,  # Inherit from SIR defaults
            dur_inf=ss.constant(v=1000),  # Once a zombie, always a zombie! Units are years.

            p_fast=ss.bernoulli(p=0.10),  # Probability of being fast
            dur_fast=ss.constant(v=1000),  # Duration of fast before becoming slow
            p_symptomatic=ss.bernoulli(p=1.0),  # Probability of symptoms
            p_death_on_zombie_infection=ss.bernoulli(p=0.25),  # Probability of death at time of infection

            p_death=ss.bernoulli(p=1),  # All zombies die instead of recovering
        )
        self.update_pars(pars, **kwargs)

        self.define_states(
            ss.BoolArr('fast', 