In [1]:
# Make the project's model and data-processing modules loadable via `using`.
for subdir in ("src/models", "src/processing")
    push!(LOAD_PATH, normpath(@__DIR__, "../../", subdir))
end
# Widen text output; Julia's displaysize honors ENV["COLUMNS"] (stored as a string).
ENV["COLUMNS"] = 200;

In [2]:
using Dates
using JuMP
using CSV
using DataFrames
using LinearAlgebra

In [3]:
using BedsData
using ForecastData
using GeographicData

In [4]:
using PatientAllocation

In [5]:
# Scenario: the states to coordinate and the planning horizon (both ends inclusive).
states = ["CT", "DE", "MA", "MD", "ME", "NH", "NJ", "NY", "PA", "RI"]

start_date, end_date = Date(2020, 5, 1), Date(2020, 5, 30)

# Inputs passed below to n_beds, adjacencies (as `threshold`) and patient_allocation.
pct_beds_available = 0.75
travel_threshold_hours = 4.0
hospitalized_days = 14;

In [6]:
# Problem size: number of states and number of days in the inclusive horizon.
N = length(states);
T = Dates.value(end_date - start_date) + 1;

In [7]:
# Forecast ICU admissions for every state over the horizon (IHME source, mean bound).
forecast_admitted = forecast(
    states, start_date, end_date;
    level=:state, source=:ihme,
    forecast_type=:admitted, patient_type=:icu, bound_type=:mean,
);

In [8]:
# Active ICU patients on the day before the horizon starts, flattened with `[:]`
# (presumably an N×1 result, giving one entry per state — confirm against ForecastData).
forecast_initial = let day_before = start_date - Dates.Day(1)
    forecast(
        states, day_before, day_before;
        level=:state, source=:ihme,
        forecast_type=:active, patient_type=:icu, bound_type=:mean,
    )[:]
end;

In [9]:
# Admissions in the `hospitalized_days` days before `start_date`. Downstream code subtracts
# the cumulative sum of these columns on day t, i.e. column k is treated as the patients
# leaving on planning day k (NOTE(review): confirm against PatientAllocation).
forecast_discharged = forecast(
    states, start_date - Dates.Day(hospitalized_days), start_date - Dates.Day(1);
    level=:state,
    source=:ihme,
    forecast_type=:admitted,
    patient_type=:icu,
    bound_type=:mean,
)
# Pad with zeros out to the T-day horizon: no pre-horizon admissions remain to discharge after
# day `hospitalized_days`. The padding width is clamped at zero and the result cut to T
# columns, so a horizon shorter than `hospitalized_days` no longer makes `zeros` throw on a
# negative dimension. The pad uses the forecast's own element type instead of assuming Float32.
forecast_discharged = hcat(
    forecast_discharged,
    zeros(eltype(forecast_discharged), N, max(0, T - hospitalized_days)),
)[:, 1:T];

In [10]:
# ICU bed counts per state (with the configured availability fraction) and the state
# adjacency built from Google travel data under the travel-time threshold.
beds = n_beds(states; bed_type=:icu, pct_beds_available=pct_beds_available);
adj = adjacencies(
    states;
    level=:state, source=:google, threshold=travel_threshold_hours,
);

In [11]:
# Build and solve the patient-allocation model for the scenario above.
model = patient_allocation(
    beds,
    forecast_initial,
    forecast_admitted,
    forecast_discharged,
    adj;
    hospitalized_days=hospitalized_days,
    send_new_only=true,
    sendrecieve_switch_time=3,  # keyword spelled as patient_allocation defines it
    min_send_amt=10,
    smoothness_penalty=0.001,
    setup_cost=0,
    sent_penalty=0,
    verbose=true
)
# Print the termination status before reading the solution. If the solve produced no primal
# values, `value`/`objective_value` would throw an opaque error and hide why it failed.
println("termination status: ", termination_status(model))
has_values(model) || error("patient allocation model has no solution to report")
sent = value.(model[:sent])
println("solve time: ", round(solve_time(model), digits=3), "s")
println("objective function value: ", round(objective_value(model), digits=3))

Academic license - for non-commercial use only
Academic license - for non-commercial use only
Gurobi Optimizer version 9.0.1 build v9.0.1rc0 (mac64)
Optimize a model with 7604 rows, 7360 columns and 123280 nonzeros
Model fingerprint: 0xe7b03a57
Model has 580 SOS constraints
Variable types: 4360 continuous, 0 integer (0 binary)
Semi-Variable types: 3000 continuous, 0 integer
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [1e-03, 1e+00]
  Bounds range     [1e+01, 1e+01]
  RHS range        [1e-01, 5e+03]
Presolve removed 3747 rows and 3663 columns
Presolve time: 0.25s
Presolved: 6619 rows, 5078 columns, 63458 nonzeros
Presolved model has 241 SOS constraint(s)
Variable types: 3487 continuous, 1591 integer (1591 binary)
Found heuristic solution: objective 10931.028641

Root relaxation: objective 2.365817e+03, 972 iterations, 0.06 seconds

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    Be

In [12]:
# Overflow of state `i` on day `t`: census above bed capacity, clipped at zero. The census is
#   initial active patients
#   - pre-horizon admissions discharged on days 1..t
#   + new admissions on days max(1, t - hospitalized_days)..t
#   - patients sent away on days 1..t-1
#   + patients received on days max(1, t - hospitalized_days)..t
# NOTE(review): the admission/receipt window spans hospitalized_days + 1 days, and transfers
# out on day t itself are not subtracted (`total_patients` subtracts through t). Confirm both
# conventions match the objective built in PatientAllocation.
overflow_per_day = (i, t) -> max(0,
    forecast_initial[i] - sum(forecast_discharged[i, 1:min(t, hospitalized_days)])
    + sum(forecast_admitted[i, max(1, t - hospitalized_days):t])
    - sum(sent[i, :, 1:t-1])
    + sum(sent[:, i, max(1, t - hospitalized_days):t])
    - beds[i]
)
# Overflow of state `i` summed over every day of the solved horizon.
overflow = i -> sum(overflow_per_day(i, t) for t in axes(sent, 3));

In [13]:
# Report the overflow summed over all states.
let per_state_overflow = overflow.(eachindex(states))
    println("Total overflow: ", sum(per_state_overflow))
end

Total overflow: 2365.36865234375


In [14]:
# Per-state totals over the horizon: patients sent (first index of `sent`), patients
# received (second index), and total overflow.
# NOTE(review): this global shadows `Base.summary` in this namespace; consider renaming it if
# later code needs the Base function.
summary = let total_out = sum(sent, dims=[2, 3])[:], total_in = sum(sent, dims=[1, 3])[:]
    DataFrame(
        state=states,
        total_sent=total_out,
        total_received=total_in,
        overflow=overflow.(eachindex(states)),
    )
end

Unnamed: 0_level_0,state,total_sent,total_received,overflow
Unnamed: 0_level_1,String,Float64,Float64,Float64
1,CT,300.0,0.0,45.8721
2,DE,0.0,0.0,0.0
3,MA,0.0,0.0,0.0
4,MD,0.0,0.0,0.0
5,ME,0.0,300.0,0.0
6,NH,0.0,0.0,0.0
7,NJ,4318.48,0.0,2319.5
8,NY,0.0,3168.71,0.0
9,PA,0.0,1149.77,0.0
10,RI,0.0,0.0,0.0


In [15]:
# Total patients sent from each row state to each column state over the whole horizon.
# The column names go straight into the constructor: `DataFrame(::Matrix)` without names is
# rejected by DataFrames.jl 1.0+ unless `:auto` is given, while the (matrix, names) form is
# accepted by both older and newer releases.
sent_matrix = DataFrame(sum(sent, dims=3)[:, :, 1], Symbol.(states))
insertcols!(sent_matrix, 1, :state => states)

Unnamed: 0_level_0,state,CT,DE,MA,MD,ME,NH,NJ,NY,PA,RI
Unnamed: 0_level_1,String,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,CT,0.0,0.0,0.0,0.0,300.0,0.0,0.0,0.0,0.0,0.0
2,DE,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,MA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,MD,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,ME,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,NH,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,NJ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3168.71,1149.77,0.0
8,NY,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,PA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
10,RI,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [16]:
# Flow matrix for visualization. Entry (i, j) is the total sent from state i to state j. The
# diagonal additionally gets state i's non-negative admissions minus everything it sent away.
sent_vis_matrix = sum(sent, dims=3)[:, :, 1] +
                  diagm(sum(max.(0, forecast_admitted), dims=2)[:] - sum(sent, dims=[2, 3])[:])
# Name the columns in the constructor: `DataFrame(::Matrix)` without names is rejected by
# DataFrames.jl 1.0+.
sent_vis_matrix = DataFrame(sent_vis_matrix, Symbol.(states));

In [17]:
# Patients in state `i` on day `t`, counting that day's transfers in both directions:
# initial census, minus pre-horizon discharges through day t, plus admissions and receipts
# within the trailing window, minus everything sent away on days 1..t.
total_patients = (i, t) -> begin
    window = max(1, t - hospitalized_days):t
    forecast_initial[i] - sum(forecast_discharged[i, 1:min(t, hospitalized_days)]) +
        sum(forecast_admitted[i, window]) -
        sum(sent[i, :, 1:t]) +
        sum(sent[:, i, window])
end;

In [18]:
# Day-by-day outcome table with one row per (state, day): transfers, census, capacity and
# overflow. The per-state tables come from a comprehension and are concatenated once. The
# previous version reassigned the global `outcomes` inside a top-level `for` loop, which only
# works in interactive soft scope (REPL/IJulia), fails when run as a script, and re-copied
# the accumulated table on every iteration.
# `sent_to`/`sent_from` hold (state, amount) pairs for the day, or the string "[]" when
# nothing moved (presumably so empty entries print as "[]", e.g. in the CSV export).
outcomes = reduce(vcat, [DataFrame(
    state=fill(s, T),
    day=start_date .+ Dates.Day.(0:T-1),
    sent=sum(sent[i, :, :], dims=1)[:],
    received=sum(sent[:, i, :], dims=1)[:],
    new_patients=forecast_admitted[i, :],
    total_patients=[total_patients(i, t) for t in 1:T],
    capacity=fill(beds[i], T),
    overflow=[overflow_per_day(i, t) for t in 1:T],
    sent_to=[sum(sent[i, :, t]) > 0 ? collect(zip(states[sent[i, :, t] .> 0], sent[i, sent[i, :, t] .> 0, t])) : "[]" for t in 1:T],
    sent_from=[sum(sent[:, i, t]) > 0 ? collect(zip(states[sent[:, i, t] .> 0], sent[sent[:, i, t] .> 0, i, t])) : "[]" for t in 1:T],
) for (i, s) in enumerate(states)])
# CSV.write("patient_allocation_results.csv", outcomes)
println("First day:")
filter(row -> row.day == start_date, outcomes)

First day:


Unnamed: 0_level_0,state,day,sent,received,new_patients,total_patients,capacity,overflow,sent_to,sent_from
Unnamed: 0_level_1,String,Date,Float64,Float64,Float32,Float64,Float32,Float64,Any,Any
1,CT,2020-05-01,10.0,0.0,97.655,759.622,723.75,45.8721,"[(""ME"", 10.0)]",[]
2,DE,2020-05-01,0.0,0.0,10.3184,76.7873,195.75,0.0,[],[]
3,MA,2020-05-01,0.0,0.0,145.0,866.003,1490.25,0.0,[],[]
4,MD,2020-05-01,0.0,0.0,81.6351,681.535,967.5,0.0,[],[]
5,ME,2020-05-01,0.0,10.0,1.3129,21.2305,207.0,0.0,[],"[(""CT"", 10.0)]"
6,NH,2020-05-01,0.0,0.0,8.67527,47.8704,206.25,0.0,[],[]
7,NJ,2020-05-01,285.255,0.0,285.255,2104.52,1317.0,1072.77,"[(""NY"", 241.499), (""PA"", 43.7558)]",[]
8,NY,2020-05-01,0.0,241.499,363.467,2445.33,3225.75,0.0,[],"[(""NJ"", 241.499)]"
9,PA,2020-05-01,0.0,43.7558,202.97,1381.32,2743.5,0.0,[],"[(""NJ"", 43.7558)]"
10,RI,2020-05-01,0.0,0.0,21.8264,146.6,281.25,0.0,[],[]


In [19]:
# Daily outcomes for one state of interest.
s = "NJ"
filter(r -> r.state == s, outcomes)

Unnamed: 0_level_0,state,day,sent,received,new_patients,total_patients,capacity,overflow,sent_to,sent_from
Unnamed: 0_level_1,String,Date,Float64,Float64,Float32,Float64,Float32,Float64,Any,Any
1,NJ,2020-05-01,285.255,0.0,285.255,2104.52,1317.0,1072.77,"[(""NY"", 241.499), (""PA"", 43.7558)]",[]
2,NJ,2020-05-02,279.638,0.0,279.638,1772.6,1317.0,735.233,"[(""NY"", 241.499), (""PA"", 38.1383)]",[]
3,NJ,2020-05-03,276.052,0.0,276.052,1452.46,1317.0,411.515,"[(""NY"", 237.914), (""PA"", 38.1383)]",[]
4,NJ,2020-05-04,128.798,0.0,276.384,1288.18,1317.0,99.9754,"[(""NY"", 90.6592), (""PA"", 38.1383)]",[]
5,NJ,2020-05-05,128.798,0.0,281.698,1132.6,1317.0,0.0,"[(""NY"", 90.6592), (""PA"", 38.1383)]",[]
6,NJ,2020-05-06,128.798,0.0,281.198,971.69,1317.0,0.0,"[(""NY"", 90.6592), (""PA"", 38.1383)]",[]
7,NJ,2020-05-07,128.798,0.0,265.053,782.854,1317.0,0.0,"[(""NY"", 90.6592), (""PA"", 38.1383)]",[]
8,NJ,2020-05-08,128.798,0.0,261.34,577.339,1317.0,0.0,"[(""NY"", 90.6592), (""PA"", 38.1383)]",[]
9,NJ,2020-05-09,128.798,0.0,254.691,359.216,1317.0,0.0,"[(""NY"", 90.6592), (""PA"", 38.1383)]",[]
10,NJ,2020-05-10,128.798,0.0,242.958,130.335,1317.0,0.0,"[(""NY"", 90.6592), (""PA"", 38.1383)]",[]


In [20]:
# Map each state to the states it sent any patients to over the horizon.
println("Sent to:")
let totals = dropdims(sum(sent, dims=3), dims=3)
    Dict(states[i] => states[totals[i, :] .> 0] for i in eachindex(states))
end

Sent to:


Dict{String,Array{String,1}} with 10 entries:
  "NH" => String[]
  "CT" => ["ME"]
  "RI" => String[]
  "MA" => String[]
  "ME" => String[]
  "NY" => String[]
  "NJ" => ["NY", "PA"]
  "DE" => String[]
  "PA" => String[]
  "MD" => String[]