In [None]:
import numpy as np
from scipy.optimize import fsolve
import matplotlib.pyplot as plt
import warnings
from ipywidgets import interact, interact_manual
from ipywidgets import widgets, Checkbox, fixed
from utils import riemann_tools

In [None]:
def p(volume):
    """Evaluate p(v) = -exp(v) at ``volume`` (scalar or array-like)."""
    exp_v = np.exp(volume)
    return -exp_v

def dpdv(volume):
    """Evaluate dp/dv for p(v) = -exp(v); the derivative is again -exp(v)."""
    exp_v = np.exp(volume)
    return -exp_v

def lamda_1(v):
    """Return the first (negative) characteristic speed, -sqrt(-dp/dv), at ``v``."""
    speed_magnitude = np.sqrt(-dpdv(v))
    return -speed_magnitude
    
def lamda_2(v):
    """Return the second (positive) characteristic speed, +sqrt(-dp/dv), at ``v``."""
    speed_magnitude = np.sqrt(-dpdv(v))
    return speed_magnitude
    
def exact_riemann_solution(q_l, q_r):
    """Compute the exact Riemann solution for left/right states ``(v, u)``.

    The middle state is found by intersecting, in the v-u plane, the 1-wave
    curve through ``q_l`` with the 2-wave curve through ``q_r``.  Each curve
    switches between a Hugoniot locus (shock) and an integral curve
    (rarefaction) depending on how the characteristic speed changes across
    the wave.

    Parameters
    ----------
    q_l, q_r : sequence of two floats
        Left and right states, each ordered as ``(v, u)``.

    Returns
    -------
    states : numpy.ndarray
        Array of shape (2, 3) whose columns are ``q_l``, the middle state
        and ``q_r``.
    speeds : list
        One entry per wave: a single shock speed, or a
        ``(left edge, right edge)`` tuple of speeds for a rarefaction fan.
    reval : callable
        ``reval(xi)`` returns ``(v, u)`` at the similarity coordinate ``xi``
        (presumably x/t -- confirm against the plotting helper); ``xi`` may
        be a scalar or an array.
    wave_types : list of str
        ``'shock'`` or ``'raref'`` for the 1-wave and the 2-wave.

    Notes
    -----
    If ``fsolve`` does not report convergence, a warning and the solver
    message are printed and the last iterate is used anyway.
    """
    v_l, u_l = q_l
    v_r, u_r = q_r

    # u as a function of v along the rarefaction curves (integral curves)
    # and shock curves (Hugoniot loci) through the left and right states.
    integral_curve_1 = lambda v: u_l - 2*(np.exp(v_l/2.)-np.exp(v/2.))
    integral_curve_2 = lambda v: u_r + 2*(np.exp(v_r/2.)-np.exp(v/2.))
    hugoniot_locus_1 = lambda v: u_l + np.sqrt(-(p(v)-p(v_l))*(v-v_l))
    hugoniot_locus_2 = lambda v: u_r - np.sqrt(-(p(v)-p(v_r))*(v-v_r))

    def phi_l(v):
        # States reachable from q_l by a 1-wave: a shock if the 1-speed
        # decreases across the wave, otherwise a rarefaction.
        if lamda_1(v) < lamda_1(v_l):
            return hugoniot_locus_1(v)
        else:
            return integral_curve_1(v)

    def phi_r(v):
        # States that connect to q_r by a 2-wave: a shock if the 2-speed
        # decreases across the wave, otherwise a rarefaction.
        if lamda_2(v) > lamda_2(v_r):
            return hugoniot_locus_2(v)
        else:
            return integral_curve_2(v)

    # The middle state is where both wave curves give the same u.
    phi = lambda v: phi_l(v) - phi_r(v)
    guess = (v_l + v_r)/2.

    v_m, _, ier, msg = fsolve(phi, guess, full_output=True, xtol=1.e-14)
    if ier != 1:
        print('Warning: fsolve did not converge.')
        print(msg)

    v_m = v_m[0]
    u_m = phi_l(v_m)

    # ws holds the edges of the two waves, left to right; for a shock both
    # edges of that wave coincide with the shock speed.
    ws = np.zeros(4)
    wave_types = ['', '']
    speeds = [[], []]

    if lamda_1(v_l) > lamda_1(v_m):  # 1-shock
        wave_types[0] = 'shock'
        ws[0] = ws[1] = -(u_l-u_m)/(v_l-v_m)
        speeds[0] = ws[0]
    else:
        wave_types[0] = 'raref'
        ws[0] = lamda_1(v_l)
        ws[1] = lamda_1(v_m)
        speeds[0] = (ws[0], ws[1])

    if lamda_2(v_m) > lamda_2(v_r):  # 2-shock
        wave_types[1] = 'shock'
        ws[2] = ws[3] = -(u_r-u_m)/(v_r-v_m)
        speeds[1] = ws[2]
    else:
        wave_types[1] = 'raref'
        ws[2] = lamda_2(v_m)
        ws[3] = lamda_2(v_r)
        speeds[1] = (ws[2], ws[3])

    def raref1(xi):
        # State inside the 1-rarefaction fan, where xi equals the 1-speed.
        RiemannInvariant = u_l - 2*np.exp(v_l/2.)
        u = RiemannInvariant - 2*xi
        v = 2*np.log(np.abs(xi))
        return v, u

    def raref2(xi):
        # State inside the 2-rarefaction fan, where xi equals the 2-speed.
        RiemannInvariant = u_r + 2*np.exp(v_r/2.)
        u = RiemannInvariant - 2*xi
        v = 2*np.log(np.abs(xi))
        return v, u

    q_m = np.array((v_m, u_m))
    states = np.column_stack([q_l, q_m, q_r])

    def reval(xi):
        # Evaluate each fan formula only at xi clipped into that fan's own
        # speed range (values inside the fan are unchanged).  The formulas
        # contain log|xi|, which is -inf at xi = 0 -- a point lying between
        # the two waves -- and multiplying -inf (or inf for xi = +/-inf) by
        # a False mask would yield nan instead of the correct constant state.
        rar1 = raref1(np.clip(xi, ws[0], ws[1]))
        rar2 = raref2(np.clip(xi, ws[2], ws[3]))

        # Boolean masks selecting the five regions, left to right.
        left = (xi <= ws[0])
        fan1 = (xi > ws[0])*(xi <= ws[1])
        middle = (xi > ws[1])*(xi <= ws[2])
        fan2 = (xi > ws[2])*(xi <= ws[3])
        right = (xi > ws[3])

        v_out = left*v_l + fan1*rar1[0] + middle*v_m + fan2*rar2[0] + right*v_r
        u_out = left*u_l + fan1*rar1[1] + middle*u_m + fan2*rar2[1] + right*u_r
        return v_out, u_out

    return states, speeds, reval, wave_types

In [None]:
# Left and right states of the example problem, each ordered as (v, u).
q_l = np.array([1.0, 1.0])
q_r = np.array([3.0, 4.0])

states, speeds, reval, wave_types = exact_riemann_solution(q_l=q_l, q_r=q_r)

In [None]:
# Columns are the left, middle and right states; rows are v and u.
print(states)

In [None]:
# One entry per wave: a single speed for a shock, or a pair of edge speeds
# for a rarefaction fan.
print(speeds)

In [None]:
# 'shock' or 'raref' for the 1-wave and the 2-wave, in that order.
print(wave_types)

In [None]:
# Plot the computed solution with t=0.3; the trailing semicolon keeps the
# notebook from echoing the call's return value.
riemann_tools.plot_riemann(states, speeds, reval, wave_types, t=0.3);

In [None]:
def lambda1(q,xi):
    """Return the 1-characteristic speed for state ``q``; ``xi`` is unused."""
    v = q[0]
    return lamda_1(v)

def lambda2(q,xi):
    """Return the 2-characteristic speed for state ``q``; ``xi`` is unused."""
    v = q[0]
    return lamda_2(v)

def make_plot_function(q_l, q_r):
    """Solve the Riemann problem once and return a plotting callback.

    The returned ``plot_function(t, which_char)`` draws the solution with
    ``riemann_tools.plot_riemann`` and, when ``which_char`` is 1 or 2,
    overlays that family's characteristics on the first returned axis.
    Any other value of ``which_char`` draws no characteristics.
    """
    solution = exact_riemann_solution(q_l, q_r)
    states, speeds, reval, wave_types = solution

    def plot_function(t, which_char):
        axes = riemann_tools.plot_riemann(states, speeds, reval, wave_types,
                                          t=t, t_pointer=0,
                                          extra_axes=False,
                                          variable_names=['V', 'U'])

        # Pick the characteristic-speed function for the requested family.
        if which_char == 1:
            char_speed = lambda1
        elif which_char == 2:
            char_speed = lambda2
        else:
            char_speed = None

        if char_speed is not None:
            riemann_tools.plot_characteristics(reval, char_speed, None, axes[0])
        plt.show()

    return plot_function
        
def plot_riemann_solution(q_l,q_r):
    """Show an interactive plot of the Riemann solution for ``q_l``, ``q_r``.

    A slider controls ``t`` and a dropdown selects which characteristic
    family (none, 1 or 2) is overlaid.
    """
    plot_function = make_plot_function(q_l, q_r)
    time_slider = widgets.FloatSlider(value=0.1, min=0, max=1., step=0.1)
    char_dropdown = widgets.Dropdown(options=[None, 1, 2],
                                     description='Show characteristics')
    interact(plot_function, t=time_slider, which_char=char_dropdown)

In [None]:
# Interactive view of the example problem defined by q_l and q_r.
plot_riemann_solution(q_l=q_l, q_r=q_r)

In [None]:
def full_riemann(v_l=2., v_r=1., u_l=1., u_r=1.):
    """Assemble ``(v, u)`` left/right states from scalars and plot the solution."""
    left_state = np.array((v_l, u_l))
    right_state = np.array((v_r, u_r))
    plot_riemann_solution(left_state, right_state)

In [None]:
# Build controls for full_riemann's keyword arguments (v_l, v_r, u_l, u_r);
# interact_manual re-runs the function only when its button is pressed.
interact_manual(full_riemann);