In [1]:
using CSV
using Dates
using DataFrames

In [2]:
using JuMP
using Gurobi

In [3]:
function patient_allocation(beds::Array{Float32,1}, patients::Array{Float32,2})
    N, T = size(patients)
    @assert(size(beds, 1) == N)

    model = Model(Gurobi.Optimizer)
    @variable(model, sent[1:N,1:N,1:T] >= 0)
    @variable(model, dummy[1:N,1:T] >= 0)

    @objective(model, Min, sum(dummy))

    # sent <= current_patients
    @constraint(model, [t=1:T],
        sum(sent[:,:,t], dims=2) .<= sum(patients[:,1:t], dims=2)[:] .- sum(sent[:,:,1:t-1], dims=[2,3])[:] .+ sum(sent[:,:,1:t-1], dims=[1,3])
    )

    # can't send to self
    @constraint(model, [i=1:N],
        sent[i,i,:] .== 0
    )

    # dummy
    @constraint(model, [i=1:N,t=1:T],
        dummy[i,t] >= (sum(patients[i,1:t]) - sum(sent[i,:,1:t]) + sum(sent[:,i,1:t])) - beds[i]
    )

    optimize!(model)
    return model
end;

In [4]:
# select start and end dates
start_date = Date(2020, 4, 7)
end_date   = Date(2020, 5, 15)

# load the forecast data
forecast_data = CSV.read("../../data/forecasts/ihme_2020_04_12/forecast.csv", copycols=true)

# filter to US states
state_list = CSV.read("../../data/geography/state_names.csv", copycols=true)
filter!(row -> row.location_name in state_list.State, forecast_data)

# add state abbreviations
state_dict = Dict(state.State => state.Abbreviation for state in eachrow(state_list))
forecast_data.state = [state_dict[row.location_name] for row in eachrow(forecast_data)]

# sort
sort!(forecast_data, [:state, :date])

# compute net change
allbed_mean_net = Array{Float64,1}(undef, size(forecast_data,1))
forecast_start = Array{Float32,1}(undef, length(unique(forecast_data.state)))
groups = groupby(forecast_data, :state).groups
for i = 1:maximum(groups)
    mask = groups .== i
    rows = forecast_data[mask,:]
    allbed_mean_net[mask] = rows.allbed_mean - [0; rows.allbed_mean[1:end-1]]
    forecast_start[i] = sum(rows.allbed_mean[rows.date .< start_date])
end
insertcols!(forecast_data, 1, :allbed_mean_net => allbed_mean_net)

# filter by date
filter!(row -> start_date <= row.date <= end_date, forecast_data)

# group by state
forecast_data_loc = groupby(forecast_data, :state, sort=true)

# select forecast column
forecast = hcat([loc.allbed_mean_net[:] for loc in forecast_data_loc]...)'
forecast = [forecast_start Float32.(forecast)];

In [5]:
# load the beds data
beds_data = CSV.read("../../data/hospitals/hospital_locations.csv", copycols=true)

# filter
filter!(row -> row.BEDS > 0, beds_data)
filter!(row -> row.STATE in state_list.Abbreviation, beds_data)

# aggregate by state
beds_data = by(beds_data, :STATE, :BEDS => sum)

# reorder states
sort!(beds_data, :STATE)

# select beds column
beds = beds_data.BEDS_sum;
beds = Float32.(beds);

In [6]:
n, t = 20, 40
beds_small, forecast_small = beds[1:n], forecast[1:n,1:t];

In [7]:
model = patient_allocation(beds_small, forecast_small);

Academic license - for non-commercial use only
Academic license - for non-commercial use only
Gurobi Optimizer version 9.0.1 build v9.0.1rc0 (mac64)
Optimize a model with 17600 rows, 16800 columns and 12800800 nonzeros
Model fingerprint: 0x2fbaa22e
Coefficient statistics:
  Matrix range     [1e+00, 1e+00]
  Objective range  [1e+00, 1e+00]
  Bounds range     [0e+00, 0e+00]
  RHS range        [2e+02, 9e+04]

Concurrent LP optimizer: primal simplex, dual simplex, and barrier
Showing barrier log only...

Presolve removed 1180 rows and 800 columns
Presolve time: 4.83s
Presolved: 16420 rows, 16000 columns, 12183980 nonzeros


Barrier performed 0 iterations in 5.88 seconds
Barrier solve interrupted - model solved by another algorithm


Solved with dual simplex
Solved in 62 iterations and 6.13 seconds
Optimal objective  0.000000000e+00


In [8]:
sent = value.(model[:sent]);
outcomes = DataFrame(
    state=state_list.Abbreviation[1:n],
    total_sent=sum(sent, dims=[2,3])[:],
    total_received=sum(sent, dims=[1,3])[:]
)

Unnamed: 0_level_0,state,total_sent,total_received
Unnamed: 0_level_1,String,Float64,Float64
1,AL,251.415,1780.94
2,AK,0.0,2907.75
3,AZ,0.0,0.0
4,AR,0.0,0.0
5,CA,0.0,1769.77
6,CO,0.0,0.0
7,CT,5735.75,0.0
8,DE,0.0,0.0
9,DC,0.0,0.0
10,FL,0.0,0.0
