In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.special import comb

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import copy

import plotly.graph_objects as go
from plotly.subplots import make_subplots

class Parameters_SEIR_pool_testing:
    """Parameter set for the SEIR model with pooled testing.

    Stores the primary (user-supplied) parameters together with the rates
    derived from them:

    * ``sigma = 1 / t_E``    (t_E: time spent latent / exposed)
    * ``gamma = 1 / t_I``    (t_I: time spent infectious)
    * ``Omega = 1 / t_P``    (t_P: pool-test turn-about time)
    * ``omega = 1 / t_S``    (t_S: standard-test turn-about time)
    * ``delta = 1 / t_Q``    (t_Q: quarantine time)
    * ``beta  = R0 * gamma``
    * ``r_p   = r_v / (1 - r_v * t_P)``

    Derived rates are always recomputed from the primary parameters, so they
    cannot go stale when a parameter is changed.

    Raises:
        ValueError: if ``r_v * t_P >= 1`` (``r_p`` would be infinite or negative).
    """

    # Primary parameters that change_attributes_and_return_new_obj may overwrite.
    _SETTABLE_ATTRIBUTES = ("R0", "Npool", "r_v", "t_E", "t_I", "t_P", "t_S", "t_Q")

    def __init__(self, R0, Npool, r_v, t_E, t_I, t_P, t_S, t_Q):

        self.R0 = R0
        self.Npool = Npool
        self.r_v = r_v

        self.t_E = t_E
        self.t_I = t_I
        self.t_P = t_P
        self.t_S = t_S
        self.t_Q = t_Q

        self._update_derived_rates()

    def _update_derived_rates(self):
        """Recompute every rate that depends on the primary parameters.

        Raises:
            ValueError: if ``r_v * t_P >= 1``.
        """
        self.sigma = 1 / self.t_E
        self.gamma = 1 / self.t_I
        self.Omega = 1 / self.t_P
        self.omega = 1 / self.t_S
        self.delta = 1 / self.t_Q
        self.beta = self.R0 * self.gamma

        if self.r_v * self.t_P >= 1:
            # Raise rather than print: otherwise r_p would be left unset (new
            # object) or stale (modified copy) and later results silently wrong.
            raise ValueError("r_v * t_P must be less than 1")

        self.r_p = self.r_v / (1 - self.r_v * self.t_P)

    def change_attributes_and_return_new_obj(self, attribute_str_attribute_val_pairs):
        """Return a modified deep copy of this parameter set.

        Args:
            attribute_str_attribute_val_pairs: iterable of ``(name, value)``
                pairs, where ``name`` is one of ``R0, Npool, r_v, t_E, t_I,
                t_P, t_S, t_Q``. Pairs with any other name are ignored.

        Returns:
            A new ``Parameters_SEIR_pool_testing``; ``self`` is not modified.
            All derived rates (sigma, gamma, Omega, omega, delta, beta, r_p)
            are recomputed once, after every pair has been applied, so only
            the final combination is validated.

        Raises:
            ValueError: if the resulting ``r_v * t_P >= 1``.
        """
        # Deep copy so the caller's object is never mutated.
        new_obj = copy.deepcopy(self)

        for attribute_str, attribute_val in attribute_str_attribute_val_pairs:
            if attribute_str in self._SETTABLE_ATTRIBUTES:
                setattr(new_obj, attribute_str, attribute_val)

        new_obj._update_derived_rates()

        return new_obj
    
class SEIR_pool_testing:
            
    def prob_positive_pool_negative(self, x, Npool):
        
        # This returns the probability of an individual who isn't infected themselves of  being in a pool test
        # with Npool - 1 other individuals that returns a negative result given the prevalence is x
        # Note the prevalance x is really the prevalence of individuals that would trigger a positive
        # pool test rather than the prevalence of infection. The I compartment in this model
        
        return 1 - (1 - x) ** (Npool - 1)
    
    def prob_negative_pool_negative(self, x, Npool):
        
        # This returns the probability of an individual who isn't infected themselves of  being in a pool test
        # with Npool - 1 other individuals that returns a positive result given the prevalence is x
        # Note the prevalance x is really the prevalence of individuals that would trigger a positive
        # pool test rather than the prevalence of infection. The I compartment in this model
        
        return (1 - x) ** (Npool - 1)
        
    def _fun(self, t, y, beta, Npool, r_p, sigma, gamma, Omega, omega, delta):
        
        S, Sppos, Spneg, Sneg = y[0:4]
        E, Eppos, Epneg, Epos, EQ = y[4:9]
        I, Ippos, Ipneg, Ipos, IQ = y[9:14]
        R, Rppos, Rpneg, Rpos, Rneg, RQ = y[14:20]
        
        x = I/(S + E + I + R)
        
        d_S_dt = omega * Sneg + Npool * Omega * Spneg - beta * S * (I + Ippos + Ipneg) - Npool * r_p * S
        d_Sppos_dt = Npool * r_p * self.prob_positive_pool_negative(x, Npool) * S - beta * Sppos * (I + Ippos + Ipneg) - Npool * Omega * Sppos
        d_Spneg_dt = Npool * r_p * self.prob_negative_pool_negative(x, Npool) * S - beta * Spneg * (I + Ippos + Ipneg) - Npool * Omega * Spneg
        d_Sneg_dt = Npool * Omega * Sppos - omega * Sneg
        
        d_E_dt = beta * S * (I + Ippos + Ipneg) + Npool * Omega * Epneg + delta * EQ - Npool * r_p * E - sigma * E
        d_Eppos_dt = beta * Sppos * (I + Ippos + Ipneg) + Npool * self.prob_positive_pool_negative(x, Npool) * r_p * E - Npool * Omega * Eppos - sigma * Eppos
        d_Epneg_dt = beta * Spneg * (I + Ippos + Ipneg) + Npool * self.prob_negative_pool_negative(x, Npool) * r_p * E - Npool * Omega * Epneg - sigma * Epneg
        d_Epos_dt = Npool * Omega * Eppos - omega * Epos - sigma * Epos
        d_EQ_dt = omega * Epos - delta * EQ - sigma * EQ
        
        d_I_dt = sigma * E + Npool * Omega * Ipneg + delta * IQ - Npool * r_p * I - gamma * I
        d_Ippos_dt = sigma * Eppos + Npool * r_p * I - Npool * Omega * Ippos - gamma * Ippos
        d_Ipneg_dt = sigma * Epneg - Npool * Omega * Ipneg - gamma * Ipneg
        d_Ipos_dt = sigma * Epos + Npool * Omega * Ippos - omega * Ipos - gamma * Ipos
        d_IQ_dt = sigma * EQ + omega * Ipos - delta * IQ - gamma * IQ
        
        d_R_dt = gamma * I + Npool * Omega * Rpneg + omega * Rneg + delta * RQ - Npool * r_p * R
        d_Rppos_dt = gamma * Ippos + Npool * self.prob_positive_pool_negative(x, Npool) * r_p * R - Npool * Omega * Rppos
        d_Rpneg_dt = gamma * Ipneg + Npool * self.prob_negative_pool_negative(x, Npool) * r_p * R - Npool * Omega * Rpneg
        d_Rpos_dt = gamma * Ipos - omega * Rpos
        d_Rneg_dt = Npool * Omega * Rppos - omega * Rneg
        d_RQ_dt = gamma * IQ + omega * Rpos - delta * RQ
        
        dydt = [d_S_dt, d_Sppos_dt, d_Spneg_dt, d_Sneg_dt, d_E_dt, d_Eppos_dt, d_Epneg_dt, d_Epos_dt, d_EQ_dt, d_I_dt, d_Ippos_dt, d_Ipneg_dt, d_Ipos_dt, d_IQ_dt, d_R_dt, d_Rppos_dt, d_Rpneg_dt, d_Rpos_dt, d_Rneg_dt, d_RQ_dt]
        
        return dydt
    
    def generate_solution(self, t_linspace, param_obj, y0 = None, method = 'LSODA'):
        
        # This generates a solution over a linear time space t_linspace
        
        if y0 == None:
            
            y0 = [7999999/8000000 * param_obj.Omega / (param_obj.Omega + param_obj.r_p), 0, 7999999/8000000 * param_obj.r_p / (param_obj.Omega + param_obj.r_p), 0, 1/8000000  * param_obj.Omega / (param_obj.Omega + param_obj.r_p), 0, 1/8000000 * param_obj.r_p / (param_obj.Omega + param_obj.r_p), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        
        t0, tf = t_linspace[0], t_linspace[-1]
        
        parameter_args = (param_obj.beta, param_obj.Npool, param_obj.r_p, param_obj.sigma, param_obj.gamma, param_obj.Omega, param_obj.omega, param_obj.delta)
        
        return solve_ivp(self._fun, (t0, tf), y0, method, args = parameter_args, t_eval = t_linspace)
    
    def generate_disease_compartments_sol(self, tau_linspace, param_obj):
        
        # This generates the probabilities that a person will be in each of the given disease compartments 
        # at a certain age-of-infection tau
            
        y0 = [0, 0, 0, 0, param_obj.Omega / (param_obj.Omega + param_obj.r_p), 0, param_obj.r_p / (param_obj.Omega + param_obj.r_p), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        
        t0, tf = tau_linspace[0], tau_linspace[-1]
        
        parameter_args = (param_obj.beta, param_obj.Npool, param_obj.r_p, param_obj.sigma, param_obj.gamma, param_obj.Omega, param_obj.omega, param_obj.delta)
        
        return solve_ivp(self._fun, (t0, tf), y0, method = 'LSODA', args = parameter_args, t_eval = tau_linspace)    
    
    def S(self, compartment_string, sol_y):
        
        if compartment_string == " ":
            
            return sol_y[0]
        
        elif compartment_string == 'ppos':
            
            return sol_y[1]
        
        elif compartment_string == 'pneg':
            
            return sol_y[2]
        
        elif compartment_string == 'neg':
            
            return sol_y[3]
    
    def E(self, compartment_string, sol_y):
        
        if compartment_string == " ":
            
            return sol_y[4]
        
        elif compartment_string == 'ppos':
            
            return sol_y[5]
        
        elif compartment_string == 'pneg':
            
            return sol_y[6]
        
        elif compartment_string == 'pos':
            
            return sol_y[7]
        
        elif compartment_string == 'Q':
            
            return sol_y[8]             
        
    def I(self, compartment_string, sol_y):
            
        if compartment_string == " ":
            
            return sol_y[9]
        
        elif compartment_string == 'ppos':
            
            return sol_y[10]
        
        elif compartment_string == 'pneg':
            
            return sol_y[11]
        
        elif compartment_string == 'pos':
            
            return sol_y[12]
        
        elif compartment_string == 'Q':
            
            return sol_y[13]
        
        elif compartment_string == 'total':
            
            return [a + b + c + d + e for (a, b, c, d, e) in zip(sol_y[9], sol_y[10], sol_y[11], sol_y[12], sol_y[13])]
        
        elif compartment_string == 'mixing':
            
            return [a + b + c for (a, b, c) in zip(sol_y[9], sol_y[10], sol_y[11])]
            
    def R(self, compartment_string, sol):

        if compartment_string == " ":
            
            return sol[14]
        
        elif compartment_string == 'ppos':
            
            return sol[15]
        
        elif compartment_string == 'pneg':
            
            return sol[16]
        
        elif compartment_string == 'pos':
            
            return sol[17]
        
        elif compartment_string == 'neg':
            
            return sol[18]
        
        elif compartment_string == 'Q':
            
            return sol[19]               
        
    def E_proportion_infected_population(self, sol):
        
        return [E / (E + I) for (E, I) in zip(self.E(" ", sol), self.I(" ", sol))]
    
    def I_proportion_infected_population(self, sol):
            
        return [I / (E + I) for (E, I) in zip(self.E(" ", sol), self.I(" ", sol))]

    def _probability_I_in_pool(self, S, E, I, R, Npool):
        
        # Given S, E, I and R at a particular time, this returns the probability that a pool will contain 
        # at least one member from I. Mainly used to calculate the pool sensitivity
        
        return sum([comb(Npool, i) * (I / (S + E + I + R)) ** i * ((S + E + R) / (S + E + I + R)) ** (Npool - i) for i in range(1, Npool + 1)])
    
    def _probability_E_or_I_in_pool(self, S, E, I, R, Npool):
        
        # Given S, E, I and R at a particular time, this returns the probability that a pool will contain
        # at least one member from E or I. Mainly used to calculate the pool sensitivity.
        
        return sum([comb(Npool, i) * ((E + I) /(S + E + I + R)) ** i * ((S + R) / (S + E + I + R)) ** (Npool - i) for i in range(1, Npool + 1)])
    
    def sensitivity(self, Npool, sol):
        
        # Returns the sensitivity of the pool tests throughout the epidemic.
        # Sensitivity is the probability that a pool test will return a positive
        # result given there is someone infected in the pool.
        
        
        sensitivity = [self._probability_I_in_pool(S, E, I, R, Npool) / self._probability_E_or_I_in_pool(S, E, I, R, Npool) for (S, E, I, R) in zip(self.S(" ", sol), self.E(" ", sol), self.I(" ", sol), self.R(" ", sol))]
        
          
        #[np.exp( np.log((S + E + I + R) ** Npool - (S + E + R) ** Npool) - np.log((S + E + I + R) ** Npool - (S + R) ** Npool) ) for (S, E, I, R) in zip(self.S(" ", sol), self.E(" ", sol), self.I(" ", sol), self.R(" ", sol))]
        
        
        
        return sensitivity
    
    def r_SEIR_traditional(self, param_obj):
        
        return (1/2) * (- (param_obj.sigma + param_obj.gamma) + np.sqrt((param_obj.sigma + param_obj.gamma)**2 + 4 * param_obj.sigma * (param_obj.beta - param_obj.gamma)))  
    
    def E_proportion_infected_population_limiting_SEIR_traditional(self, param_obj):
        
        r = self.r_SEIR_traditional(param_obj)
        
        return (param_obj.gamma + r) / (param_obj.sigma + param_obj.gamma + r)

    def I_proportion_infected_population_limiting_SEIR_traditional(self, param_obj):
        
        r = self.r_SEIR_traditional(param_obj)
        
        return (param_obj.sigma) / (param_obj.sigma + param_obj.gamma + r)

    def prob_step(self, compartment_from, compartment_to, param_obj):
        
        if compartment_from == "E":
            
            if compartment_to == "Epneg":
                
                return param_obj.Npool * param_obj.r_p / (param_obj.Npool * param_obj.r_p + param_obj.sigma)
            
            elif compartment_to == "I":
                
                return param_obj.sigma / (param_obj.Npool * param_obj.r_p + param_obj.sigma)
            
            else:
                
                return 0
            
        elif compartment_from == "Epneg":
            
            if compartment_to == "E":
                
                return param_obj.Npool * param_obj.Omega / (param_obj.Npool * param_obj.Omega + param_obj.sigma)
            
            elif compartment_to == "Ipneg":
                
                return param_obj.sigma / (param_obj.Npool * param_obj.Omega + param_obj.sigma)
            
            else:
                
                return 0
            
        elif compartment_from == "I":
        
            if compartment_to == 'Ippos':
                
                return param_obj.Npool * param_obj.r_p / (param_obj.Npool * param_obj.r_p + param_obj.gamma)
            
            elif compartment_to == 'R':
                
                return param_obj.gamma / (param_obj.Npool * param_obj.r_p + param_obj.gamma)
            
            else:
                
                return 0
            
        elif compartment_from == "Ippos":
        
            if compartment_to == "Ipos":
                
                return param_obj.Npool * param_obj.Omega / (param_obj.Npool * param_obj.Omega + param_obj.gamma)
            
            elif compartment_to == "Rppos":
                
                return param_obj.gamma / (param_obj.Npool * param_obj.Omega + param_obj.gamma)
            
            else:
                
                return 0
        
        elif compartment_from == "Ipneg":
            
            if compartment_to == "I":
                
                return param_obj.Npool * param_obj.Omega / (param_obj.Npool * param_obj.Omega + param_obj.gamma)
            
            elif compartment_to == "Rpneg":
                
                return param_obj.gamma / (param_obj.Npool * param_obj.Omega + param_obj.gamma)
            
            else:
                
                return 0
            
        elif compartment_from == "Ipos":
            
            if compartment_to == "IQ":
                
                return param_obj.omega / (param_obj.omega + param_obj.gamma)
            
            elif compartment_to == "Rpos":
                
                return param_obj.gamma / (param_obj.omega + param_obj.gamma)
            
            else:
                
                return 0            
        
        elif compartment_from == "IQ":
            
            if compartment_to == "I":
                
                return param_obj.delta / (param_obj.delta + param_obj.gamma)
            
            elif compartment_to == "RQ":
                
                return param_obj.gamma / (param_obj.delta + param_obj.gamma)
            
            else:
                
                return 0
            
    def prob_chain(self, compartments, param_obj):
        
        chain_probability = 1
        
        for i in range(len(compartments)-1):
            
            chain_probability = chain_probability * self.prob_step(compartments[i], compartments[i + 1], param_obj)
            
        return chain_probability
    
    def prob_return(self, compartment, param_obj):
        
        if compartment == "E":
            
            return self.prob_chain(['E', 'Epneg', 'E'], param_obj)
        
        elif compartment == 'Epneg':
            
            return self.prob_chain(['Epneg', 'E', 'Epneg'], param_obj)
        
        elif compartment == 'I':
            
            return self.prob_chain(['I', 'Ippos', 'Ipos', 'IQ', 'I'], param_obj)
        
        elif compartment == 'Ippos':
            
            return self.prob_chain(['Ippos', 'Ipos', 'IQ', 'I', 'Ippos'], param_obj)   
        
        elif compartment == 'Ipneg':
            
            return 0
        
        elif compartment == 'Ipos':
            
            return self.prob_chain(['Ipos', 'IQ', 'I', 'Ippos', 'Ipos'], param_obj)
        
        elif compartment == 'IQ':
            
            return self.prob_chain(['IQ', 'I', 'Ippos', 'Ipos', 'IQ'], param_obj)
        
    def prob_advance(self, compartment, param_obj):
        
        if compartment == "E":
            
            return self.prob_step("E", "I", param_obj) * 1 / (1 - self.prob_return("E", param_obj))
        
        elif compartment == "Epneg":
            
            return self.prob_step("Epneg", "Ipneg", param_obj) * 1 / (1 - self.prob_return("Epneg", param_obj))
        
        elif compartment == 'I':
            
            return self.prob_step('I', 'R', param_obj) * 1 / (1 - self.prob_return('I', param_obj))
        
        elif compartment == 'Ippos':
            
            return self.prob_step('Ippos', 'Rppos', param_obj) * 1 / (1 - self.prob_return('Ippos', param_obj))
        
        elif compartment == 'Ipneg':
            
            return self.prob_step('Ipneg', 'Rpneg', param_obj) * 1 / (1 - self.prob_return('Ipneg', param_obj))
        
        elif compartment == 'Ipos':
            
            return self.prob_step('Ipos', 'Rpos', param_obj) * 1 / (1 - self.prob_return('Ipos', param_obj))
        
        elif compartment == 'IQ':
            
            return self.prob_step('IQ', 'RQ', param_obj) * 1 / (1 - self.prob_return('IQ', param_obj))
                
    def prob_infected_in(self, compartment, param_obj):
        
        if compartment == 'S':
            
            return param_obj.Omega / (param_obj.Omega + param_obj.r_p)
        
        elif compartment == 'Spneg':
            
            return param_obj.r_p / (param_obj.Omega + param_obj.r_p)
        
        else:
            
            return None
        
    def expected_sojourn_time(self, compartment, param_obj):
        
        if compartment == 'I':
            
            return 1 / (param_obj.Npool * param_obj.r_p + param_obj.gamma) * 1 / (1 - self.prob_return("I", param_obj))
        
        elif compartment == 'Ippos':
            
            return 1 / (param_obj.Npool * param_obj.Omega + param_obj.gamma) * 1 / (1 - self.prob_return("Ippos", param_obj))
        
        elif compartment == 'Ipneg':
            
            return 1 / (param_obj.Npool * param_obj.Omega + param_obj.gamma) * 1 / (1 - self.prob_return("Ipneg", param_obj))

    def prob_hit(self, compartment_from, compartment_to, param_obj):
        
        if compartment_from == "E":
            
            if compartment_to == 'I':
                
                return self.prob_advance("E", param_obj) + self.prob_step('E', 'Epneg', param_obj) * self.prob_advance("Epneg", param_obj) * self.prob_chain(["Ipneg", "I"], param_obj)
            
            elif compartment_to == 'Ippos':
                
                return self.prob_advance("E", param_obj) * self.prob_step('I', "Ippos", param_obj) + self.prob_step('E', "Epneg", param_obj) * self.prob_advance("Epneg", param_obj) * self.prob_chain(['Ipneg', 'I', 'Ippos'], param_obj) 
            
            elif compartment_to == 'Ipneg':
                
                return self.prob_step('E', 'Epneg', param_obj) * self.prob_advance('Epneg', param_obj)
            
            elif compartment_to == 'Ipos':
                
                return self.prob_advance('E', param_obj) * self.prob_chain(['I', 'Ippos', 'Ipos'], param_obj) + self.prob_step('E', 'Epneg', param_obj) * self.prob_advance('Epneg', param_obj) * self.prob_chain(['Ipneg', 'I', 'Ippos', 'Ipos'], param_obj)
            
            elif compartment_to == 'IQ':
                
                return self.prob_advance('E', param_obj) * self.prob_chain(['I', 'Ippos', 'Ipos', 'IQ'], param_obj) + self.prob_step('E', 'Epneg', param_obj) * self.prob_advance('Epneg', param_obj) * self.prob_chain(['Ipneg', 'I', 'Ippos', 'Ipos', 'IQ'], param_obj)
        
        elif compartment_from == 'Epneg':
            
            if compartment_to == 'I':
                
                return self.prob_advance("Epneg", param_obj) * self.prob_step('Ipneg', "I", param_obj) + self.prob_step('Epneg', "E", param_obj) * self.prob_advance("E", param_obj)
            
            elif compartment_to == 'Ippos':
                
                return self.prob_advance("Epneg", param_obj) * self.prob_chain(["Ipneg", "I", "Ippos"], param_obj) + self.prob_step('Epneg', 'E', param_obj) * self.prob_advance("E", param_obj) * self.prob_step('I', "Ippos", param_obj) 
            
            elif compartment_to == 'Ipneg':
                
                return self.prob_advance('Epneg', param_obj)
            
            elif compartment_to == 'Ipos':
                
                return self.prob_advance('Epneg', param_obj) * self.prob_chain(['Ipneg', 'I', 'Ippos', 'Ipos'], param_obj) + self.prob_step('Epneg', 'E', param_obj) * self.prob_chain(['I', 'Ippos', 'Ipos'], param_obj)
            
            elif compartment_to == 'IQ':
                
                return self.prob_advance('Epneg', param_obj) * self.prob_chain(['Ipneg', 'I', 'Ippos', 'Ipos', 'IQ'], param_obj) + self.prob_step('Epneg', 'E', param_obj) * self.prob_advance('E', param_obj) * self.prob_chain(['I', 'Ippos', 'Ipos', 'IQ'], param_obj)
        
    def generate_R0_p(self, param_obj):
        
        expected_infected_time_individual_in_E = self.prob_hit('E', 'I', param_obj) * self.expected_sojourn_time('I', param_obj) + self.prob_hit('E', 'Ippos', param_obj) * self.expected_sojourn_time('Ippos', param_obj) + self.prob_hit('E', 'Ipneg', param_obj) * self.expected_sojourn_time('Ipneg', param_obj)
        expected_infected_time_individual_in_Epneg = self.prob_hit('Epneg', 'I', param_obj) * self.expected_sojourn_time('I', param_obj) + self.prob_hit('Epneg', 'Ippos', param_obj) * self.expected_sojourn_time('Ippos', param_obj) + self.prob_hit('Epneg', 'Ipneg', param_obj) * self.expected_sojourn_time('Ipneg', param_obj)
        
        return param_obj.beta * (
            self.prob_infected_in('S', param_obj) * expected_infected_time_individual_in_E +
            self.prob_infected_in('Spneg', param_obj) * expected_infected_time_individual_in_Epneg
        )
                                                             
    def final_size(self, R0):
        
        # Returns the expected final size of the traditional SEIR epidemic
        
        f = lambda x: 1 - x - np.exp(-R0 * x)
        
        return optimize.newton(f, 0.99)

    def plot_R0_vs_r_v_varying_Npool(self, N_pool_vals, other_params_obj, save_fig = False, r_v_vals = None):
        """Plot the pool reproduction number against the testing rate r_v.

        One curve per pool size; all other parameters come from ``other_params_obj``.

        Args:
            N_pool_vals: iterable of pool sizes.
            other_params_obj: ``Parameters_SEIR_pool_testing`` supplying R0,
                t_E, t_I, t_P, t_S and t_Q.
            save_fig: when True, also write ``Images/R0_vs_r_v_varying_Npool.png``.
            r_v_vals: testing rates to evaluate; defaults to
                ``np.linspace(0.0, 0.2)``. Each value must satisfy
                ``r_v * t_P < 1``.
        """
        if r_v_vals is None:
            r_v_vals = np.linspace(0.0, 0.2)

        fig = go.Figure()

        # Iterate over the argument actually passed in: the old loop read an
        # undefined/global `Npool_vals` and failed on a fresh kernel.
        for Npool in N_pool_vals:

            param_objs = [Parameters_SEIR_pool_testing(other_params_obj.R0, Npool, r_v, other_params_obj.t_E, other_params_obj.t_I, other_params_obj.t_P, other_params_obj.t_S, other_params_obj.t_Q) for r_v in r_v_vals]
            R0_vals = [self.generate_R0_p(p_obj) for p_obj in param_objs]

            fig.add_trace(go.Scatter(x = r_v_vals,
                                     y = R0_vals,
                                     name = f"Pool size = {Npool}"
                                    )
                         )

        fig.update_layout(title = "Effect of increased testing on the pool reproductive number for differing pool sizes",
                          xaxis_title = r"$r_v$",
                          yaxis_title = r"$\mathbb{R}_0^P$"
                         )

        if save_fig:
            fig.write_image('Images/R0_vs_r_v_varying_Npool.png', scale = 4)

        fig.show()

    def plot_R0_vs_t_P_varying_Npool(self, t_P_vals, Npool_vals, other_params_obj, save_fig = False):
        """Plot the pool reproduction number against the pool-test turnaround time t_P.

        One curve per pool size in ``Npool_vals``; every other parameter is
        copied from ``other_params_obj``. When ``save_fig == True`` the figure
        is also written to ``Images/R0_vs_t_P_varying_Npool.png``.
        """
        base = other_params_obj
        figure = go.Figure()

        for pool_size in Npool_vals:
            # Build all parameter sets first, then evaluate R0^P for each.
            candidates = [
                Parameters_SEIR_pool_testing(base.R0, pool_size, base.r_v, base.t_E, base.t_I, turnaround, base.t_S, base.t_Q)
                for turnaround in t_P_vals
            ]
            reproduction_numbers = list(map(self.generate_R0_p, candidates))
            figure.add_trace(go.Scatter(x = t_P_vals, y = reproduction_numbers, name = f"Pool size = {pool_size}"))

        figure.update_layout(
            title = "Effect of pool turn-about time on the pool reproductive number",
            xaxis_title = r"$t_P$",
            yaxis_title = r"$\mathbb{R}_0^P$",
        )

        if save_fig == True:
            figure.write_image('Images/R0_vs_t_P_varying_Npool.png', scale = 4)

        figure.show()
     
    def plot_R0_vs_t_S_varying_Npool(self, t_S_vals, Npool_vals, other_params_obj, save_fig = False):
        """Plot the pool reproduction number against the standard-test turnaround time t_S.

        One curve per pool size in ``Npool_vals``; the remaining parameters are
        copied from ``other_params_obj``. When ``save_fig == True`` the figure
        is also written to ``Images/R0_vs_t_S_varying_Npool.png``.
        """
        base = other_params_obj

        def pool_R0_curve(pool_size):
            candidates = [
                Parameters_SEIR_pool_testing(base.R0, pool_size, base.r_v, base.t_E, base.t_I, base.t_P, standard_time, base.t_Q)
                for standard_time in t_S_vals
            ]
            return [self.generate_R0_p(candidate) for candidate in candidates]

        traces = [
            go.Scatter(x = t_S_vals, y = pool_R0_curve(pool_size), name = f"Pool size = {pool_size}")
            for pool_size in Npool_vals
        ]

        figure = go.Figure()
        for trace in traces:
            figure.add_trace(trace)

        figure.update_layout(
            title = "Effect of standard test turn-about time on the pool reproductive number",
            xaxis_title = r"$t_S$",
            yaxis_title = r"$\mathbb{R}_0^P$",
        )

        if save_fig == True:
            figure.write_image('Images/R0_vs_t_S_varying_Npool.png', scale = 4)

        figure.show()

    def plot_R0_vs_t_Q_varying_Npool(self, t_Q_vals, Npool_vals, other_params_obj, save_fig = False):
        """Plot the pool reproduction number against the quarantine time t_Q.

        One curve per pool size in ``Npool_vals``; the remaining parameters are
        copied from ``other_params_obj``. When ``save_fig == True`` the figure
        is also written to ``Images/R0_vs_t_Q_varying_Npool.png``.
        """
        base = other_params_obj
        figure = go.Figure()

        for pool_size in Npool_vals:
            candidates = []
            for quarantine_time in t_Q_vals:
                candidates.append(Parameters_SEIR_pool_testing(base.R0, pool_size, base.r_v, base.t_E, base.t_I, base.t_P, base.t_S, quarantine_time))

            reproduction_numbers = []
            for candidate in candidates:
                reproduction_numbers.append(self.generate_R0_p(candidate))

            figure.add_trace(go.Scatter(x = t_Q_vals, y = reproduction_numbers, name = f"Pool size = {pool_size}"))

        figure.update_layout(
            title = "Effect of quarantine time on the pool reproductive number",
            xaxis_title = r"$t_Q$",
            yaxis_title = r"$\mathbb{R}_0^P$",
        )

        if save_fig == True:
            figure.write_image('Images/R0_vs_t_Q_varying_Npool.png', scale = 4)

        figure.show()

    def plot_R0_vs_t_tot_varying_Npool(self, t_tot_vals, Npool_vals, other_params_obj, save_fig = False):
        """Plot the pool reproduction number against the total infected time t_E + t_I.

        The latent share t_E / (t_E + t_I) of ``other_params_obj`` is held
        fixed while the total is swept over ``t_tot_vals``. There is one curve
        per pool size. When ``save_fig == True`` the figure is also written to
        ``Images/R0_vs_t_tot_varying_Npool.png``.
        """
        base = other_params_obj
        latent_share = base.t_E / (base.t_E + base.t_I)
        figure = go.Figure()

        for pool_size in Npool_vals:
            candidates = [
                Parameters_SEIR_pool_testing(base.R0, pool_size, base.r_v,
                                             latent_share * total_time, (1-latent_share) * total_time,
                                             base.t_P, base.t_S, base.t_Q)
                for total_time in t_tot_vals
            ]
            reproduction_numbers = [self.generate_R0_p(candidate) for candidate in candidates]
            figure.add_trace(go.Scatter(x = t_tot_vals, y = reproduction_numbers, name = f"Pool size = {pool_size}"))

        figure.update_layout(
            title = "Effect of increasing total infected time on the pool reproductive number",
            xaxis_title = r"$t_{tot}$",
            yaxis_title = r"$\mathbb{R}_0^P$",
        )

        if save_fig == True:
            figure.write_image('Images/R0_vs_t_tot_varying_Npool.png', scale = 4)

        figure.show()
        
    def plot_R0_vs_latent_proportion_varying_Npool(self, Npool_vals, other_params_obj, save_fig = False, latent_proportion_vals = None):
        """Plot the pool reproduction number against the latent share of infected time.

        The total infected time t_E + t_I of ``other_params_obj`` is held fixed
        while the latent share t_E / (t_E + t_I) is swept. There is one curve
        per pool size in ``Npool_vals``.

        Args:
            Npool_vals: iterable of pool sizes.
            other_params_obj: ``Parameters_SEIR_pool_testing`` supplying the
                remaining parameters.
            save_fig: when True, also write
                ``Images/R0_vs_latent_proportion_varying_Npool.png``.
            latent_proportion_vals: latent shares to evaluate; defaults to
                ``np.linspace(0.01, 0.99, 1000)``. Values of exactly 0 or 1
                make t_E or t_I zero, which Parameters_SEIR_pool_testing
                divides by.
        """
        if latent_proportion_vals is None:
            latent_proportion_vals = np.linspace(0.01, 0.99, 1000)

        fig = go.Figure()

        t_tot = other_params_obj.t_E + other_params_obj.t_I

        for Npool in Npool_vals:

            param_objs = [Parameters_SEIR_pool_testing(other_params_obj.R0, Npool, other_params_obj.r_v, latent_proportion * t_tot, (1-latent_proportion) * t_tot, other_params_obj.t_P, other_params_obj.t_S, other_params_obj.t_Q) for latent_proportion in latent_proportion_vals]
            R0_vals = [self.generate_R0_p(p_obj) for p_obj in param_objs]

            fig.add_trace(go.Scatter(x = latent_proportion_vals,
                                     y = R0_vals,
                                     name = f"Pool size = {Npool}"
                                    )
                         )

        # Title typo fixed: "time spend latent" -> "time spent latent".
        fig.update_layout(title = "Effect of increasing the proportion of infected time spent latent on the pool reproductive number",
                          xaxis_title = "Latent proportion",
                          yaxis_title = r"$\mathbb{R}_0^P$"
                         )

        if save_fig == True:

            fig.write_image('Images/R0_vs_latent_proportion_varying_Npool.png', scale = 4)

        fig.show()
 
    def interactive_plot_pool_R0(self, t_P, t_S, t_Q, Npool, t_tot, latent_time_ratio, other_params_obj):
        """Widget callback: plot R0^P against the testing rate r_v with matplotlib.

        The total infected time ``t_tot`` is split into t_E and t_I by
        ``latent_time_ratio``; R0 is taken from ``other_params_obj``. r_v is
        swept over 1000 points from 0 up to just below ``1 / t_P``, the
        largest rate for which ``r_v * t_P < 1``.
        """
        t_E = t_tot * latent_time_ratio
        t_I = t_tot * (1 - latent_time_ratio)

        # Stay strictly below 1 / t_P. Clamp at 0 so very long turnaround times
        # (t_P > 100) do not produce negative testing rates.
        r_v_max = max(1 / t_P - 0.01, 0.0)
        r_v_values = np.linspace(0, r_v_max, 1000)

        param_objs = [Parameters_SEIR_pool_testing(other_params_obj.R0, Npool, r_v, t_E, t_I, t_P, t_S, t_Q) for r_v in r_v_values]

        pool_R0_values = [self.generate_R0_p(param_obj) for param_obj in param_objs]

        fig, axes = plt.subplots()

        axes.plot(r_v_values, pool_R0_values, label = r'$R_0^p$')
        axes.set_xlim(0.0001, 0.2)
        axes.set_ylim((1, other_params_obj.R0+1))
        axes.set_title(r"Interactive plot of $\mathbb{R}_0^p$")
        axes.set_xlabel("Tests per unit time")
        axes.legend()

        plt.show()

    def interact_plot_sensitivity_proportion_infected_populations(self, t_linspace, R0, Npool, r_v, t_E, t_I, t_P, t_S, t_Q, other_params_obj):
        """Two-panel matplotlib figure for an interactive widget.

        A copy of `other_params_obj` with the given parameters overridden is
        solved over `t_linspace`.
        - Left panel: pool sensitivity over time.
        - Right panel: the E and I proportions of the infected population.
          Scatter markers at t = 5 show the values returned by the
          `*_limiting_SEIR_traditional` helpers.
        """
        overrides = [['R0', R0], ['Npool', Npool], ['r_v', r_v], ['t_E', t_E], ['t_I', t_I], ['t_P', t_P], ['t_S', t_S], ['t_Q', t_Q]]
        param_obj = other_params_obj.change_attributes_and_return_new_obj(overrides)

        solution = self.generate_solution(t_linspace, param_obj)

        fig, (sens_ax, frac_ax) = plt.subplots(1, 2)

        sens_ax.plot(t_linspace, self.sensitivity(Npool, solution.y))
        sens_ax.set_title("Pool sensitivity")
        sens_ax.set_xlabel("t")
        sens_ax.set_ylim((0, 1))

        # Each time series is paired with its limiting value, drawn as a single marker at t = 5.
        frac_ax.plot(t_linspace, self.E_proportion_infected_population(solution.y), label = "E", color = "blue")
        frac_ax.scatter(5, self.E_proportion_infected_population_limiting_SEIR_traditional(param_obj), color = "blue")
        frac_ax.plot(t_linspace, self.I_proportion_infected_population(solution.y), label = "I", color = "red")
        frac_ax.scatter(5, self.I_proportion_infected_population_limiting_SEIR_traditional(param_obj), color = "red")

        frac_ax.set_title("% of E + I that is E and I")
        frac_ax.legend()
        frac_ax.set_xlabel("t")
        frac_ax.set_ylim((0, 1))

        plt.tight_layout()
        plt.show()

    def plot_sensitivity(self, t_linspace, param_obj):
        """Show a plotly line plot of pool sensitivity over `t_linspace` for one parameter set."""
        state = self.generate_solution(t_linspace, param_obj).y
        sensitivity_trace = go.Scatter(x = t_linspace, y = self.sensitivity(param_obj.Npool, state))

        fig = go.Figure(sensitivity_trace)
        fig.update_layout(title = "Plot of pool sensitivity", xaxis_title = r"$t$")
        fig.show()

    def plot_sensitivity_total_I_E_proportion_I_proportion_total_infected(self, t_linspace, param_obj, save_fig = False):
        """Two-panel plotly figure of the solution over `t_linspace`.

        - Left panel (y-range fixed to [0, 1]): pool sensitivity and the
          `self.I(" ", ...)` curve, labelled Sigma I.
        - Right panel: E and I proportions of the infected population.

        When `save_fig == True` the figure is also written to
        Images/plot_sensitivity_total_I_E_proportion_I_proportion_total_infected.png
        at scale 4 before being shown.
        """
        state = self.generate_solution(t_linspace, param_obj).y

        fig = make_subplots(rows = 1, cols = 2)

        # (values, legend name, subplot column); list order fixes trace/legend order.
        series = [
            (self.sensitivity(param_obj.Npool, state), "Sensitivity", 1),
            (self.I(" ", state), r"$\Sigma I$", 1),
            (self.E_proportion_infected_population(state), "E proportion", 2),
            (self.I_proportion_infected_population(state), "I proportion", 2),
        ]
        for values, legend_name, column in series:
            fig.add_trace(go.Scatter(x = t_linspace, y = values, name = legend_name), row = 1, col = column)

        fig.update_yaxes(range = [0, 1], row = 1, col = 1)
        for column in (1, 2):
            fig.update_xaxes(title_text = r'$t$', row = 1, col = column)

        if save_fig == True:
            fig.write_image("Images/plot_sensitivity_total_I_E_proportion_I_proportion_total_infected.png", scale = 4)

        fig.show()
        
    def plot_E_proportion_and_I_proportion_total_infected(self, t_linspace, param_obj, save_fig = False):
        """Plot the E and I proportions of the infected population over time (plotly).

        Parameters
        ----------
        t_linspace : array-like
            Time points at which the model is solved and plotted.
        param_obj : Parameters_SEIR_pool_testing
            Model parameters passed to `generate_solution`.
        save_fig : bool, default False
            If True, also write
            Images/plot_E_proportion_and_I_proportion_total_infected.png at
            scale 4 before showing the figure.
        """
        sol = self.generate_solution(t_linspace, param_obj)

        fig = go.Figure(go.Scatter(x = t_linspace,
                                   y = self.E_proportion_infected_population(sol.y),
                                   name = "E proportion"
                                  )
                       )

        # Legend label spelled to match the "I proportion" trace in
        # plot_sensitivity_total_I_E_proportion_I_proportion_total_infected.
        fig.add_trace(go.Scatter(x = t_linspace,
                                 y = self.I_proportion_infected_population(sol.y),
                                 name = "I proportion"
                                )
                     )

        fig.update_layout(xaxis_title = r"$t$")

        if save_fig == True:

            fig.write_image("Images/plot_E_proportion_and_I_proportion_total_infected.png", scale = 4)

        fig.show()

    def generate_infectivity_curve(self, tau_vals, param_obj):
        """Transmission occurring at each age-of-infection in `tau_vals`.

        The model is solved from a 20-component initial state whose mass is
        split between slot 4 (Omega / (Omega + r_p)) and slot 6
        (r_p / (Omega + r_p)), all other slots zero. Those slots are presumably
        the two entry compartments of a newly infected individual; confirm
        against the state ordering used by generate_solution.

        Returns a list of beta * I("mixing") along that solution.
        """
        split_total = param_obj.Omega + param_obj.r_p

        initial_state = [0] * 20
        initial_state[4] = param_obj.Omega / split_total
        initial_state[6] = param_obj.r_p / split_total

        trajectory = self.generate_solution(tau_vals, param_obj, initial_state).y
        mixing_infectious = self.I("mixing", trajectory)

        return [param_obj.beta * value for value in mixing_infectious]

    def plot_infectivity_curves_varying_Npool(self, tau_vals, Npool_vals, param_obj, save_fig = False):
        """Overlay infectivity curves (transmission vs. age-of-infection) for several pool sizes.

        The first trace uses a copy of `param_obj` with Npool = 0, labelled
        "No testing, M = 1". Each further trace uses Npool from `Npool_vals`
        and is labelled with M = generate_R0_p(...) / param_obj.R0, rounded to
        1 decimal place.

        When `save_fig == True` the figure is written to
        Images/plot_infectivity_curves_varying_Npool.png (scale 4) before being shown.
        """
        baseline_obj = param_obj.change_attributes_and_return_new_obj([['Npool', 0]])
        baseline_curve = self.generate_infectivity_curve(tau_vals, baseline_obj)
        fig = go.Figure(go.Scatter(x = tau_vals, y = baseline_curve, name = "No testing, M = 1"))

        for pool_size in Npool_vals:
            pooled_obj = param_obj.change_attributes_and_return_new_obj([['Npool', pool_size]])
            pooled_curve = self.generate_infectivity_curve(tau_vals, pooled_obj)
            reduction = round(self.generate_R0_p(pooled_obj) / param_obj.R0, 1)

            fig.add_trace(go.Scatter(x = tau_vals,
                                     y = pooled_curve,
                                     name = f"Npool = {pool_size}, M = {reduction}"))

        fig.update_layout(xaxis_title = r'$\tau$')

        if save_fig == True:
            fig.write_image("Images/plot_infectivity_curves_varying_Npool.png", scale = 4)

        fig.show()
    
    def plot_probability_isolated_given_age_of_infection_varying_Npool(self, tau_vals, Npool_vals, param_obj, save_fig = False):
        """Plot infectivity without testing alongside P(isolated) vs. age-of-infection per pool size.

        - The blue trace is the infectivity curve for a copy of `param_obj`
          with r_v = 0.
        - For each pool size, the plotted "P(Isolated)" is the sum
          E("pos") + E("Q") + I("pos") + I("Q") along the solution from
          generate_disease_compartments_sol.
        - Traces are coloured from yellow (first pool size) to red (last).

        Parameters
        ----------
        tau_vals : array-like
            Ages of infection.
        Npool_vals : sequence of int
            Pool sizes. May be a list, tuple or numpy array; duplicates are
            allowed.
        param_obj : Parameters_SEIR_pool_testing
            Base parameters.
        save_fig : bool, default False
            If True, also write
            Images/plot_probability_isolated_given_age_of_infection_varying_Npool.png
            (scale 4).
        """
        colours = np.linspace(0, 1, len(Npool_vals))

        fig = go.Figure(go.Scatter(x = tau_vals,
                                   y = self.generate_infectivity_curve(tau_vals, param_obj.change_attributes_and_return_new_obj([['r_v', 0]])),
                                   name = 'Infectivity',
                                   line = dict(color = 'blue')
                                  )
                       )

        # Pair each pool size with its colour positionally. Looking the colour up
        # with Npool_vals.index(...) is quadratic, gives duplicates the same
        # colour, and fails for numpy arrays.
        for colour, Npool in zip(colours, Npool_vals):

            sol_y = self.generate_disease_compartments_sol(tau_vals, param_obj.change_attributes_and_return_new_obj([['Npool', Npool]])).y

            prob_isolated = [Epos + EQ + Ipos + IQ for (Epos, EQ, Ipos, IQ) in zip(self.E("pos", sol_y), self.E("Q", sol_y), self.I("pos", sol_y), self.I("Q", sol_y))]

            fig.add_trace(go.Scatter(x = tau_vals,
                                     y = prob_isolated,
                                     name = f"P(Isolated) with Npool = {Npool}",
                                     line = dict(color = f'rgba(255, {255 * (1 - colour)}, 0, 1)')
                                    )
                         )

        fig.update_layout(xaxis_title = r'$\tau$')

        if save_fig == True:

            fig.write_image("Images/plot_probability_isolated_given_age_of_infection_varying_Npool.png", scale = 4)

        fig.show()

    def code_check_DFE_probabilities_sum_one(self, param_obj):
        """Check #1: an infection happens either in S or in Spneg, so the two probabilities sum to 1.

        Prints a message when the sum differs from 1 by more than 1e-7.
        """
        total = sum(self.prob_infected_in(compartment, param_obj) for compartment in ('S', 'Spneg'))

        if abs(total - 1) > 0.0000001:
            print("The probability of being infected in S added to the probability of being infected in Sneg doesn't add to 1")
        
    def code_check_probabilities_not_return_equal_recovered_before_return(self, param_obj, tol = 1e-15):
        """Check #2: P(never returning to a compartment) equals P(leaving via a non-returning route).

        For each latent/infectious compartment, 1 - prob_return(compartment) is
        compared with the sum of prob_step / prob_chain over the explicit exit
        routes listed below. Every mismatch larger than `tol` prints a message
        followed by both values, in the order E, Epneg, I, Ippos, Ipneg, Ipos, IQ.

        Parameters
        ----------
        param_obj : Parameters_SEIR_pool_testing
            Parameters forwarded to the probability helpers.
        tol : float, default 1e-15
            Absolute tolerance. The default is close to double-precision
            resolution for values near 1; loosen it if spurious mismatches
            appear.
        """
        def ret(compartment):
            return self.prob_return(compartment, param_obj)

        def step(origin, destination):
            return self.prob_step(origin, destination, param_obj)

        def chain(path):
            return self.prob_chain(path, param_obj)

        # (compartment, 1 - P(return), total probability of the non-returning exit routes)
        checks = [
            ('E', 1 - ret('E'),
             step('E', 'I') + chain(['E', 'Epneg', 'Ipneg'])),
            ('Epneg', 1 - ret('Epneg'),
             step('Epneg', 'Ipneg') + chain(['Epneg', 'E', 'I'])),
            ('I', 1 - ret('I'),
             step('I', 'R') + chain(['I', 'Ippos', 'Rppos']) + chain(['I', 'Ippos', 'Ipos', 'Rpos']) + chain(['I', 'Ippos', 'Ipos', 'IQ', 'RQ'])),
            ('Ippos', 1 - ret('Ippos'),
             step('Ippos', 'Rppos') + chain(['Ippos', 'Ipos', 'Rpos']) + chain(['Ippos', 'Ipos', 'IQ', 'RQ']) + chain(['Ippos', 'Ipos', 'IQ', 'I', 'R'])),
            ('Ipneg', 1 - ret('Ipneg'),
             step('Ipneg', 'Rpneg') + step('Ipneg', 'I')),
            ('Ipos', 1 - ret('Ipos'),
             step('Ipos', 'Rpos') + chain(['Ipos', 'IQ', 'RQ']) + chain(['Ipos', 'IQ', 'I', 'R']) + chain(['Ipos', 'IQ', 'I', 'Ippos', 'Rppos'])),
            ('IQ', 1 - ret('IQ'),
             step('IQ', 'RQ') + chain(['IQ', 'I', 'R']) + chain(['IQ', 'I', 'Ippos', 'Rppos']) + chain(['IQ', 'I', 'Ippos', 'Ipos', 'Rpos'])),
        ]

        for compartment, via_return, via_routes in checks:

            if abs(via_return - via_routes) > tol:

                # Print both sides for every compartment so each failure is diagnosable.
                print('The probability of not returning to ' + compartment + ' isn\'t correct')
                print(str(via_return))
                print(str(via_routes))
 
    def code_check_return_probabilities_equal(self, param_obj, tol = 1e-15):
        """Check #3: return probabilities should be equal within each disease stage.

        Latent group: E, Epneg. Infectious group: I, Ippos, Ipos, IQ. If the
        spread (max - min) of prob_return within a group exceeds `tol`, a
        message is printed followed by each probability in the group.

        Parameters
        ----------
        param_obj : Parameters_SEIR_pool_testing
            Parameters forwarded to prob_return.
        tol : float, default 1e-15
            Absolute tolerance on the spread.
        """
        # Compare via the spread. A chain like `a == b == False` parses as
        # `a == b and b == False`, so it never flags unequal non-zero values.
        latent = [self.prob_return(compartment, param_obj) for compartment in ('E', 'Epneg')]

        if max(latent) - min(latent) > tol:

            print('Probabilities of return don\'t equal for latent compartments')
            for value in latent:
                print(str(value))

        infectious = [self.prob_return(compartment, param_obj) for compartment in ('I', 'Ippos', 'Ipos', 'IQ')]

        if max(infectious) - min(infectious) > tol:

            print('Probabilities of return don\'t equal for infectious compartments')
            for value in infectious:
                print(str(value))
           
    def code_check_step_probabilities_sum_one(self, param_obj):
        """Check #4: from each listed source compartment, the one-step probabilities sum to 1.

        For every source, prob_step is summed over the fixed destination list.
        A message naming the source is printed when the total is more than
        1e-15 away from 1.
        """
        sources = ['E', 'Epneg', 'I', 'Ippos', 'Ipos', 'IQ']
        destinations = ['E', 'Epneg', 'I', 'Ippos', 'Ipneg', 'Ipos', 'IQ', 'R', 'Rppos', 'Rpos', 'RQ']

        for source in sources:
            total = sum(self.prob_step(source, destination, param_obj) for destination in destinations)

            if abs(total - 1) > 1e-15:
                print(source + ' has an issue with it\'s step probabilty')

    def code_check_advancement_probabilities_sum_one(self, param_obj):
        """Check #5: from each latent/infectious compartment, advancement probabilities sum to 1.

        For a starting compartment, the total is:
        - prob_advance from that compartment, plus
        - for every route to another compartment of the same stage,
          prob_step / prob_chain along the route times prob_advance at its end.

        Totals more than 1e-15 away from 1 are reported, with their value, in
        the order E, Epneg, I, Ippos, Ipneg, Ipos, IQ.
        """
        def adv(compartment):
            return self.prob_advance(compartment, param_obj)

        def step(origin, destination):
            return self.prob_step(origin, destination, param_obj)

        def chain(path):
            return self.prob_chain(path, param_obj)

        totals = [
            ('E', adv('E') + step('E', 'Epneg') * adv('Epneg')),
            ('Epneg', adv('Epneg') + step('Epneg', 'E') * adv('E')),
            ('I', adv('I') + step('I', 'Ippos') * adv('Ippos') + chain(['I', 'Ippos', 'Ipos']) * adv('Ipos') + chain(['I', 'Ippos', 'Ipos', 'IQ']) * adv('IQ')),
            ('Ippos', adv('Ippos') + step('Ippos', 'Ipos') * adv('Ipos') + chain(['Ippos', 'Ipos', 'IQ']) * adv('IQ') + chain(['Ippos', 'Ipos', 'IQ', 'I']) * adv('I')),
            ('Ipneg', adv('Ipneg') + step('Ipneg', 'I') * adv('I') + chain(['Ipneg', 'I', 'Ippos']) * adv('Ippos') + chain(['Ipneg', 'I', 'Ippos', 'Ipos']) * adv('Ipos') + chain(['Ipneg', 'I', 'Ippos', 'Ipos', 'IQ']) * adv('IQ')),
            ('Ipos', adv('Ipos') + step('Ipos', 'IQ') * adv('IQ') + chain(['Ipos', 'IQ', 'I']) * adv('I') + chain(['Ipos', 'IQ', 'I', 'Ippos']) * adv('Ippos')),
            ('IQ', adv('IQ') + step('IQ', 'I') * adv('I') + chain(['IQ', 'I', 'Ippos']) * adv('Ippos') + chain(['IQ', 'I', 'Ippos', 'Ipos']) * adv('Ipos')),
        ]

        for compartment, total in totals:
            if abs(total - 1) > 1e-15:
                print('The advancement probabilities from the point of view of an individual in ' + compartment + ' don\'t sum to 1')
                print(str(total))
 
    def code_check_conservation_of_population_principle(self, param_obj):
        
        #6: Check that, for random distributions of the population among the different compartments,
        #   the vector field conserves total population, i.e. its components sum to zero.
        '''Check #6 (conservation of population) -- currently DISABLED: calling this method does nothing.

        The intended check draws 20 random 20-component state vectors, each
        normalised to sum to 1, and evaluates the vector field at each one. It
        prints a message whenever the field's components sum to more than 1e-9
        in absolute value. That code is preserved below as comments.

        `param_obj` is accepted for interface compatibility with the other
        checks but is not used.

        NOTE(review): the disabled code calls `self._fun(0, y)`, which is not
        visible in this chunk. Confirm it exists, and whether it now needs
        `param_obj`, before re-enabling. The random draws are also unseeded.
        '''
        # for i in range(20):
        #
        #     y = np.random.rand(20)
        #     y /= y.sum()
        #
        #     sum_of_derivatives = sum(self._fun(0, y))
        #
        #     if abs(sum_of_derivatives) > 0.000000001:
        #
        #         print(f"Sum of vector field componants equals {sum_of_derivatives} instead of zero")
        
    def code_check_infectivity_curve_no_testing(self, param_obj):
        """Visually compare the model's no-testing infectivity curve with the closed-form SEIR curve.

        Both curves are plotted over tau in [0, 3 * (t_E + t_I)] (1000 points).
        - The model curve is generate_infectivity_curve on a copy of
          `param_obj` with r_v = 0.
        - The closed form, with sigma = 1/t_E and gamma = 1/t_I, is
              R0 * sigma * gamma / (sigma - gamma) * (exp(-gamma*tau) - exp(-sigma*tau)),
          whose limit as sigma -> gamma is R0 * gamma**2 * tau * exp(-gamma*tau).
        """
        tau_vals = np.linspace(0, 3 * (param_obj.t_E + param_obj.t_I), 1000)

        sigma = param_obj.sigma
        gamma = param_obj.gamma

        if np.isclose(sigma, gamma, rtol = 1e-9, atol = 0.0):
            # t_E == t_I makes the general formula 0/0 (a ZeroDivisionError for
            # Python floats), so use its analytic limit.
            traditional_infectivity_curve = param_obj.R0 * gamma ** 2 * tau_vals * np.exp(-gamma * tau_vals)
        else:
            traditional_infectivity_curve = (param_obj.R0 * sigma * gamma / (sigma - gamma)
                                             * (np.exp(-gamma * tau_vals) - np.exp(-sigma * tau_vals)))

        infectivity_curve = self.generate_infectivity_curve(tau_vals, param_obj.change_attributes_and_return_new_obj([['r_v', 0]]))

        fig = go.Figure(go.Scatter(x = tau_vals,
                                   y = traditional_infectivity_curve,
                                   name = "Traditional"
                                  )
                       )

        # Label reflects what is actually switched off above (r_v = 0).
        fig.add_trace(go.Scatter(x = tau_vals,
                                 y = infectivity_curve,
                                 name = "r_v = 0",
                                 line = dict(dash = 'dash')
                                )
                     )

        fig.update_layout(title = "Check infectivity curves match given r_v = 0 (no testing)",
                          xaxis_title = r"$\tau$")

        fig.show()
    
    def run_code_checks(self, param_obj):
        """Run every internal consistency check against `param_obj`, in a fixed order."""
        checks = (
            self.code_check_DFE_probabilities_sum_one,
            self.code_check_probabilities_not_return_equal_recovered_before_return,
            self.code_check_return_probabilities_equal,
            self.code_check_step_probabilities_sum_one,
            self.code_check_advancement_probabilities_sum_one,
            self.code_check_conservation_of_population_principle,
            self.code_check_infectivity_curve_no_testing,
        )
        for check in checks:
            check(param_obj)
                
# Baseline parameter set shared by the cells below
params_0 = Parameters_SEIR_pool_testing(R0 = 1.5, Npool = 5, r_v = 0.05,
                                        t_E = 3, t_I = 1,
                                        t_P = 1, t_S = 1, t_Q = 7)

# Build the model and run its internal consistency checks on the baseline
model_0 = SEIR_pool_testing()
model_0.run_code_checks(params_0)

# Trajectory for t in [0, 100], sampled at 1000 points
sol_0 = model_0.generate_solution(np.linspace(0, 100, num = 1000), params_0)

In [2]:
# Ages of infection from 0 to three times the total infected time t_E + t_I
tau_vals = np.linspace(start = 0, stop = 3 * (params_0.t_E + params_0.t_I), num = 1000)
Npool_vals = [2, 5, 10, 30]
model_0.plot_probability_isolated_given_age_of_infection_varying_Npool(tau_vals, Npool_vals, params_0, save_fig = False)


In [3]:
# Ages of infection from 0 to three times the total infected time t_E + t_I
tau_vals = np.linspace(start = 0, stop = 3 * (params_0.t_E + params_0.t_I), num = 1000)
Npool_vals = [2, 5, 10, 30]
model_0.plot_infectivity_curves_varying_Npool(tau_vals, Npool_vals, params_0, save_fig = False)

In [6]:
# Interact plot of the pool sensitivity over time with the proportion of the infected population that is E or I


def interact_plot_sensitivity_proportion_infected_populations(R0, Npool, r_v, t_E, t_I, t_P, t_S, t_Q):
    """Widget callback: solve the model for the given parameters and draw two panels.

    The model is solved on t in [0, 300] (1000 points).
    - Left panel: pool sensitivity and the total I curve, both on [0, 1].
    - Right panel: E and I proportions of the infected population. Markers at
      t = 5 show the values returned by the `*_limiting_SEIR_traditional`
      helpers.
    """
    t_linspace = np.linspace(0, 300, 10 * 100)

    params_interact = Parameters_SEIR_pool_testing(R0, Npool, r_v, t_E, t_I, t_P, t_S, t_Q)
    SEIR_interact = SEIR_pool_testing()
    sol = SEIR_interact.generate_solution(t_linspace, params_interact)

    fig, axes = plt.subplots(1, 2)

    # Two curves share this panel, so label them.
    axes[0].plot(t_linspace, SEIR_interact.sensitivity(Npool, sol.y), label = "Sensitivity")
    axes[0].plot(t_linspace, SEIR_interact.I('total', sol.y), label = "Total I")
    axes[0].set_title("Pool sensitivity and total I")
    axes[0].set_xlabel("t")
    axes[0].set_ylim((0, 1))
    axes[0].legend()

    axes[1].plot(t_linspace, SEIR_interact.E_proportion_infected_population(sol.y), label = "E", color = "blue")

    axes[1].scatter(5, SEIR_interact.E_proportion_infected_population_limiting_SEIR_traditional(params_interact),  color = "blue")

    axes[1].plot(t_linspace, SEIR_interact.I_proportion_infected_population(sol.y), label = "I", color = "red")

    axes[1].scatter(5, SEIR_interact.I_proportion_infected_population_limiting_SEIR_traditional(params_interact),  color = "red")

    axes[1].set_title("% of E + I that is E and I")
    axes[1].legend()
    axes[1].set_xlabel("t")
    axes[1].set_ylim((0, 1))

    plt.tight_layout()
    plt.show()
    
# Slider specs are (min, max, step); keyword order fixes the widget order
interact(interact_plot_sensitivity_proportion_infected_populations,
         R0 = (1.1, 3, 0.1), Npool = (2, 20, 1), r_v = (0.001, 0.01, 0.001),
         t_E = (0.5, 3, 0.5), t_I = (0.5, 3, 0.5),
         t_P = (0.5, 3, 0.5), t_S = (0.5, 3, 0.5),
         t_Q = (8, 14, 1))

# Fixed parameter set for sensitivity plots
sensititivy_params = Parameters_SEIR_pool_testing(R0 = 2, Npool = 10, r_v = 0.01,
                                                  t_E = 3, t_I = 1,
                                                  t_P = 1, t_S = 1, t_Q = 7)

interactive(children=(FloatSlider(value=2.0, description='R0', max=3.0, min=1.1), IntSlider(value=11, descript…

In [7]:
# Interactive plot of the pool reproductive number

def interactive_plot_pool_R0(t_P, t_S, t_Q, Npool, t_tot, latent_time_ratio, other_params_obj = None):
    """Notebook-level wrapper that plots R_0^p against the testing rate r_v.

    Delegates to SEIR_pool_testing.interactive_plot_pool_R0 so there is a
    single implementation.

    Parameters
    ----------
    t_P, t_S, t_Q : float
        Pool-test, standard-test and quarantine durations.
    Npool : int
        Pool size.
    t_tot : float
        Total time spent latent plus infectious.
    latent_time_ratio : float
        Fraction of `t_tot` spent latent.
    other_params_obj : Parameters_SEIR_pool_testing, optional
        Source of R0. Defaults to the baseline `params_0` from the first cell.
    """
    if other_params_obj is None:
        other_params_obj = params_0

    SEIR_pool_testing().interactive_plot_pool_R0(t_P, t_S, t_Q, Npool, t_tot, latent_time_ratio, other_params_obj)
    
interact(model_0.interactive_plot_pool_R0,
         t_P = (1/2, 10, 1/2), t_S = (1/2, 10, 1/2), t_Q = (1, 12, 1),
         Npool = (2, 50, 1), t_tot = (1, 12, 1), latent_time_ratio = (0.1, 0.9, 0.1),
         other_params_obj = fixed(params_0))

interactive(children=(FloatSlider(value=5.0, description='t_P', max=10.0, min=0.5, step=0.5), FloatSlider(valu…

<function ipywidgets.widgets.interaction._InteractFactory.__call__.<locals>.<lambda>(*args, **kwargs)>

In [8]:
# Effect of the testing rate r_v (tests per unit time) on the pool reproductive number, for several pool sizes

Npool_vals = [2, 5, 10, 20, 50]
model_0.plot_R0_vs_r_v_varying_Npool(Npool_vals, params_0)

In [9]:
# Effect of the pool-test turnaround time t_P on the pool reproductive number, for several pool sizes

t_P_vals = np.linspace(0.1, 10, num = 1000)
Npool_vals = [2, 5, 10, 20, 50]
model_0.plot_R0_vs_t_P_varying_Npool(t_P_vals, Npool_vals, params_0, save_fig = False)

In [10]:
# Effect of the standard-test turnaround time t_S on the pool reproductive number, for several pool sizes

t_S_vals = np.linspace(0.1, 10, num = 1000)
Npool_vals = [2, 5, 10, 20, 50]
model_0.plot_R0_vs_t_S_varying_Npool(t_S_vals, Npool_vals, params_0, save_fig = False)

In [11]:
# Effect of the quarantine time t_Q on the pool reproductive number, for several pool sizes

t_Q_vals = np.linspace(0.1, 10, num = 1000)
Npool_vals = [2, 5, 10, 20, 50]
model_0.plot_R0_vs_t_Q_varying_Npool(t_Q_vals, Npool_vals, params_0, save_fig = False)

In [12]:
# Effect of the proportion of total infected time spent latent on the pool reproductive number, for several pool sizes

Npool_vals = [2, 5, 10, 20, 50]
model_0.plot_R0_vs_latent_proportion_varying_Npool(Npool_vals, params_0, save_fig = False)

In [13]:
# Effect of the total infected time t_tot = t_E + t_I on the pool reproductive number, for several pool sizes

t_tot_vals = np.linspace(0.1, 8, num = 1000)
Npool_vals = [2, 5, 10, 20, 50]
model_0.plot_R0_vs_t_tot_varying_Npool(t_tot_vals, Npool_vals, params_0)