In [1]:
push!(LOAD_PATH, normpath(@__DIR__, "../../", "src/models"));
push!(LOAD_PATH, normpath(@__DIR__, "../../", "src/processing"));
ENV["COLUMNS"] = 200;

In [2]:
using Dates
using JuMP
using CSV
using DataFrames
using LinearAlgebra

In [3]:
using BedsData
using ForecastData
using GeographicData

In [4]:
using PatientAllocation

In [5]:
states = ["CT", "DE", "MA", "MD", "ME", "NH", "NJ", "NY", "PA", "RI"]

start_date = Date(2020, 5, 10)
end_date   = Date(2020, 6, 1)

pct_beds_available = 0.25
travel_threshold_hours = 4.0
hospitalized_days = 14;

In [6]:
N = length(states);
T = (end_date - start_date).value + 1;

In [7]:
forecast_admitted = forecast(
    states, start_date, end_date,
    level=:state,
    source=:columbia,
    forecast_type=:admitted,
    patient_type=:regular,
    bound_type=:mean,
);

In [8]:
forecast_initial = forecast(
    states, start_date-Dates.Day(1), start_date-Dates.Day(1),
    level=:state,
    source=:columbia,
    forecast_type=:active,
    patient_type=:regular,
    bound_type=:mean,
    hospitalized_days=hospitalized_days
)[:];

In [9]:
forecast_discharged = forecast(
    states, start_date-Dates.Day(hospitalized_days), start_date-Dates.Day(1),
    level=:state,
    source=:columbia,
    forecast_type=:admitted,
    patient_type=:regular,
    bound_type=:mean,
)
forecast_discharged = hcat(forecast_discharged, zeros(Float32, N, T - hospitalized_days));

In [10]:
beds = n_beds(states, bed_type=:all, pct_beds_available=pct_beds_available);
adj = adjacencies(states, level=:state, source=:google, threshold=travel_threshold_hours);

In [11]:
model = patient_allocation(
    beds,
    forecast_initial,
    forecast_admitted,
    forecast_discharged,
    adj,
    hospitalized_days=hospitalized_days,
    send_new_only=true,
    sendrecieve_switch_time=3,
    min_send_amt=10,
    smoothness_penalty=0.001,
    setup_cost=0,
    sent_penalty=0,
    verbose=true
)
sent = value.(model[:sent])
println("termination status: ", termination_status(model))
println("solve time: ", round(solve_time(model), digits=3), "s")
println("objective function value: ", round(objective_value(model), digits=3))

Academic license - for non-commercial use only
Academic license - for non-commercial use only
Gurobi Optimizer version 9.0.1 build v9.0.1rc0 (mac64)
Optimize a model with 6014 rows, 5610 columns and 130782 nonzeros
Model fingerprint: 0x8791dcd4
Model has 440 SOS constraints
Variable types: 3310 continuous, 0 integer (0 binary)
Semi-Variable types: 2300 continuous, 0 integer
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [1e-03, 1e+00]
  Bounds range     [1e+01, 1e+01]
  RHS range        [5e-01, 3e+04]
Presolve removed 2639 rows and 2275 columns
Presolve time: 0.17s
Presolved: 5767 rows, 4531 columns, 62747 nonzeros
Presolved model has 385 SOS constraint(s)
Variable types: 3324 continuous, 1207 integer (1207 binary)
Found heuristic solution: objective 1093959.7500

Root relaxation: objective 9.608415e+05, 3026 iterations, 0.13 seconds

    Nodes    |    Current Node    |     Objective Bounds      |     Work
 Expl Unexpl |  Obj  Depth IntInf | Incumbent    B

In [12]:
overflow_per_day = (i,t) -> sum(max.(0,
    forecast_initial[i] - sum(forecast_discharged[i,1:min(t,hospitalized_days)])
    + sum(forecast_admitted[i,max(1,t-hospitalized_days):t])
    - sum(sent[i,:,1:t-1])
    + sum(sent[:,i,max(1,t-hospitalized_days):t])
    - beds[i])
)
overflow = i -> sum(overflow_per_day(i,t) for t=1:size(sent,3));

In [13]:
println("Total overflow: ", sum(overflow.(1:length(states))))

Total overflow: 962115.75


In [14]:
summary = DataFrame(
    state=states,
    total_sent=sum(sent, dims=[2,3])[:],
    total_received=sum(sent, dims=[1,3])[:],
    overflow=overflow.(1:length(states)),
)

Unnamed: 0_level_0,state,total_sent,total_received,overflow
Unnamed: 0_level_1,String,Float64,Float64,Float64
1,CT,643.0,265.788,9501.0
2,DE,544.25,14632.4,230255.0
3,MA,8021.88,0.0,163349.0
4,MD,3073.0,0.0,77064.8
5,ME,0.0,8328.13,105240.0
6,NH,0.0,327.3,104.0
7,NJ,7942.35,0.0,176694.0
8,NY,5919.0,850.654,163360.0
9,PA,733.75,3579.96,4268.32
10,RI,1107.0,0.0,32279.2


In [15]:
sent_matrix = DataFrame(sum(sent, dims=3)[:,:,1])
rename!(sent_matrix, Symbol.(states))
insertcols!(sent_matrix, 1, :state => states)

Unnamed: 0_level_0,state,CT,DE,MA,MD,ME,NH,NJ,NY,PA,RI
Unnamed: 0_level_1,String,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,CT,0.0,0.0,0.0,0.0,643.0,0.0,0.0,0.0,0.0,0.0
2,DE,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,544.25,0.0
3,MA,265.788,0.0,0.0,0.0,6578.13,327.3,0.0,850.654,0.0,0.0
4,MD,0.0,2893.0,0.0,0.0,0.0,0.0,0.0,0.0,180.0,0.0
5,ME,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,NH,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,NJ,0.0,5086.64,0.0,0.0,0.0,0.0,0.0,0.0,2855.71,0.0
8,NY,0.0,5919.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,PA,0.0,733.75,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
10,RI,0.0,0.0,0.0,0.0,1107.0,0.0,0.0,0.0,0.0,0.0


In [16]:
sent_vis_matrix = sum(sent, dims=3)[:,:,1] + diagm(sum(max.(0, forecast_admitted), dims=2)[:] - sum(sent, dims=[2,3])[:])
sent_vis_matrix = DataFrame(sent_vis_matrix)
rename!(sent_vis_matrix, Symbol.(states));

In [17]:
total_patients = (i,t) -> (
    forecast_initial[i] - sum(forecast_discharged[i,1:min(t,hospitalized_days)])
    + sum(forecast_admitted[i,max(1,t-hospitalized_days):t])
    - sum(sent[i,:,1:t])
    + sum(sent[:,i,max(1,t-hospitalized_days):t])
);

In [18]:
outcomes = DataFrame()
for (i,s) in enumerate(states)
    single_state_outcome = DataFrame(
        state=fill(s, T),
        day=start_date .+ Dates.Day.(0:T-1),
        sent=sum(sent[i,:,:], dims=1)[:],
        received=sum(sent[:,i,:], dims=1)[:],
        new_patients=forecast_admitted[i,:],
        total_patients=[total_patients(i,t) for t in 1:T],
        capacity=fill(beds[i], T),
        overflow=[overflow_per_day(i,t) for t in 1:T],
        sent_to=[sum(sent[i,:,t])>0 ? collect(zip(states[sent[i,:,t] .> 0], sent[i,sent[i,:,t].>0,t])) : "[]" for t in 1:T],
        sent_from=[sum(sent[:,i,t])>0 ? collect(zip(states[sent[:,i,t] .> 0], sent[sent[:,i,t].>0,i,t])) : "[]" for t in 1:T],
    )
    outcomes = vcat(outcomes, single_state_outcome)
end
# CSV.write("patient_allocation_results.csv", outcomes)
println("First day:")
filter(row -> row.day == start_date, outcomes)

First day:


Unnamed: 0_level_0,state,day,sent,received,new_patients,total_patients,capacity,overflow,sent_to,sent_from
Unnamed: 0_level_1,String,Date,Float64,Float64,Float32,Float64,Float32,Float64,Any,Any
1,CT,2020-05-10,231.0,0.0,231.0,3350.0,1840.25,1740.75,"[(""ME"", 231.0)]",[]
2,DE,2020-05-10,0.0,5119.75,102.0,6261.75,506.75,5755.0,[],"[(""MD"", 775.0), (""NJ"", 1645.0), (""NY"", 2062.0), (""PA"", 637.75)]"
3,MA,2020-05-10,1539.0,0.0,1539.0,20528.0,3997.75,18069.2,"[(""ME"", 1539.0)]",[]
4,MD,2020-05-10,775.0,0.0,775.0,9392.0,2228.75,7938.25,"[(""DE"", 775.0)]",[]
5,ME,2020-05-10,0.0,2073.0,1.0,2128.0,668.0,1460.0,[],"[(""CT"", 231.0), (""MA"", 1539.0), (""RI"", 303.0)]"
6,NH,2020-05-10,0.0,0.0,37.0,594.0,584.5,9.5,[],[]
7,NJ,2020-05-10,1645.0,0.0,1645.0,20734.0,4553.25,17825.8,"[(""DE"", 1645.0)]",[]
8,NY,2020-05-10,2062.0,0.0,2062.0,26920.0,10406.2,18575.8,"[(""DE"", 2062.0)]",[]
9,PA,2020-05-10,637.75,0.0,647.0,7872.25,7980.25,529.75,"[(""DE"", 637.75)]",[]
10,RI,2020-05-10,303.0,0.0,303.0,3665.0,674.25,3293.75,"[(""ME"", 303.0)]",[]


In [19]:
s = "NJ"
filter(row -> row.state == s, outcomes)

Unnamed: 0_level_0,state,day,sent,received,new_patients,total_patients,capacity,overflow,sent_to,sent_from
Unnamed: 0_level_1,String,Date,Float64,Float64,Float32,Float64,Float32,Float64,Any,Any
1,NJ,2020-05-10,1645.0,0.0,1645.0,20734.0,4553.25,17825.8,"[(""DE"", 1645.0)]",[]
2,NJ,2020-05-11,1556.0,0.0,1556.0,19456.0,4553.25,16458.8,"[(""DE"", 1556.0)]",[]
3,NJ,2020-05-12,1490.0,0.0,1490.0,18131.0,4553.25,15067.8,"[(""DE"", 1490.0)]",[]
4,NJ,2020-05-13,395.644,0.0,1406.0,17737.4,4553.25,13579.8,"[(""DE"", 395.644)]",[]
5,NJ,2020-05-14,0.0,0.0,1348.0,17620.4,4553.25,13067.1,[],[]
6,NJ,2020-05-15,158.65,0.0,1280.0,17195.7,4553.25,12801.1,"[(""PA"", 158.65)]",[]
7,NJ,2020-05-16,158.65,0.0,1168.0,16618.1,4553.25,12223.5,"[(""PA"", 158.65)]",[]
8,NJ,2020-05-17,158.65,0.0,1074.0,15876.4,4553.25,11481.8,"[(""PA"", 158.65)]",[]
9,NJ,2020-05-18,158.65,0.0,1008.0,14996.8,4553.25,10602.2,"[(""PA"", 158.65)]",[]
10,NJ,2020-05-19,158.65,0.0,936.0,14039.1,4553.25,9644.5,"[(""PA"", 158.65)]",[]


In [20]:
println("Sent to:")
Dict(states[i] => states[row] for (i,row) in enumerate(eachrow(sum(sent, dims=3)[:,:,1] .> 0)))

Sent to:


Dict{String,Array{String,1}} with 10 entries:
  "NH" => String[]
  "CT" => ["ME"]
  "RI" => ["ME"]
  "MA" => ["CT", "ME", "NH", "NY"]
  "ME" => String[]
  "NY" => ["DE"]
  "NJ" => ["DE", "PA"]
  "DE" => ["PA"]
  "PA" => ["DE"]
  "MD" => ["DE", "PA"]