In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

from sympy import symbols, solve, init_printing, Eq, lambdify

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Enable pretty (typeset) display of SymPy expressions in cell outputs.
init_printing()

In [None]:
# Real-valued symbols for the parameters mu and lambda and the index n.
mu, lam, n = symbols('mu lambda n', real=True)

# Symbolic t_c as a function of n:
#   t_c(n) = [mu n^2 (n^2 - 1)^2 / (1 + mu n^2) + lambda] / (n^2 - 1)
tcn = (mu*n**2*(n**2 - 1)**2/(1 + mu*n**2) + lam)/(n**2 - 1)

In [None]:
# Curves in the (lambda, mu) plane where t_c(n) == t_c(n + 1), for n = 2..6.
# Each curve is lambda expressed as a function of mu.
mmu = np.logspace(-3, 0, 100)

fig, ax = plt.subplots(figsize=(6.5*2/3, 6.5*2/3))

for n_lo in range(2, 7):
    # Eq needs an explicit right-hand side: the one-argument form Eq(expr)
    # was deprecated in SymPy 1.5 and has since been removed.
    # tcn is linear in lambda, so the equation has exactly one root.
    lam_boundary = solve(Eq(tcn.subs(n, n_lo), tcn.subs(n, n_lo + 1)), lam)[0]
    lam_nm = lambdify(mu, lam_boundary)
    ax.loglog(lam_nm(mmu), mmu, 'k')

ax.set_xlim([1., 300.])
ax.set_ylim([.003, 1.0])
ax.set_xlabel(r'$\lambda$')
ax.set_ylabel(r'$\mu$');

In [None]:
# Minimum of t_c(n) over a few low n on a log-spaced (lambda, mu) grid.
n_l, n_mu = (100, 100)
log_lam_lo, log_lam_hi = 0., 2.      # lambda from 1 to 100 (columns)
log_mu_top, log_mu_bottom = 0., -2.  # mu from 1 (row 0) down to 0.01
modes = (2, 3, 4)

# meshgrid returns arrays of shape (n_mu, n_l): rows follow mu, columns lambda.
ll, mm = np.meshgrid(np.logspace(log_lam_lo, log_lam_hi, n_l),
                     np.logspace(log_mu_top, log_mu_bottom, n_mu))

# lambdify's default numpy backend already evaluates element-wise on NumPy
# arrays, so wrapping it in np.vectorize (a per-element Python loop) is unneeded.
tcn_fxn = lambdify((lam, mu, n), tcn)

# Stack one (n_mu, n_l) slice per mode along a new trailing axis, then take the
# minimum across modes. Stacking preserves each slice's layout; reshaping to
# (n_l, n_mu, 1) would silently scramble values whenever n_l != n_mu.
tc = np.stack([tcn_fxn(ll, mm, mode) for mode in modes], axis=-1).min(axis=-1)

fig, ax = plt.subplots(figsize=(6.5*2/3, 6.5*2/3))

# The grid is uniform in log10 space, so label the image in log10 units.
# With imshow's default origin='upper', row 0 (mu = 1) is drawn at the top.
im = ax.imshow(tc, interpolation='gaussian', cmap='viridis',
               extent=(log_lam_lo, log_lam_hi, log_mu_bottom, log_mu_top))
ax.set_xlabel(r'$\log_{10}\,\lambda$')
ax.set_ylabel(r'$\log_{10}\,\mu$')
fig.colorbar(im, ax=ax, label=r'$t_c$');