In [None]:
import numpy as np
import os
from matplotlib import pyplot as plt

# Global figure settings and the two line styles that are unpacked into draw_cube.
plt.rcParams["font.size"] = 18

# Style for the real array elements: dark-blue edges, amber labels.
solid = {
    'c': '#235b8c',
    'ls': '-',
    'lw': 1.2,
    'label_kwargs': {'color': '#d99d36'},
}

# Style for the implicit (broadcast) copies. Despite its name the line style
# is solid ('-'); the "ghost" appearance comes from the low alpha and grey colour.
dotted = {
    'c': '#737373',
    'ls': '-',
    'lw': 1.0,
    'alpha': 0.3,
    'label_kwargs': {'color': '#737373'},
}

# Offset of each cube's back face along both x and y.
depth = 0.3

In [None]:
def draw_cube(ax, xy, size, depth=0.4, edges=None, label=None, label_kwargs=None, **kwargs):
    """Draw a wireframe cube in oblique projection onto ``ax``.

    The front face is the square whose lower-left corner is ``xy`` and whose
    side is ``size``; the back face is that square shifted by ``depth`` in
    both x and y. Edges are numbered:

    * 1-4   front face: top, right, bottom, left
    * 5-8   front-to-back connectors: top-left, top-right, bottom-right, bottom-left
    * 9-12  back face: top, right, bottom, left

    Parameters
    ----------
    ax : object with ``plot(xs, ys, **kw)`` and ``text(x, y, s, **kw)``
        (a matplotlib Axes in this notebook).
    xy : (x, y)
        Lower-left corner of the front face.
    size : float
        Side length of the cube.
    depth : float
        Offset of the back face along both axes.
    edges : container of int, optional
        Edge numbers to draw (tested with ``in``); all 12 when None. Edges are
        always drawn in ascending number order, whatever order is given.
    label : str, optional
        Text centred on the front face; nothing is drawn when it is falsy.
    label_kwargs : dict, optional
        Extra keyword arguments for ``ax.text``.
    **kwargs
        Forwarded to every ``ax.plot`` call (colour, line style, ...).
    """
    if label_kwargs is None:
        label_kwargs = {}
    if edges is None:
        edges = list(range(1, 13))

    x, y = xy

    # Corner coordinates. Each sum is spelled in the same operand order as the
    # per-edge expressions it replaces, so float results are bit-identical.
    left, right = x, x + size
    bottom, top = y, y + size
    back_left, back_right = x + depth, x + depth + size
    back_bottom, back_top = y + depth, y + depth + size
    right_pushed_back = x + size + depth  # end of connectors 6 and 7

    segments = {
        1: ([left, right], [top, top]),
        2: ([right, right], [bottom, top]),
        3: ([left, right], [bottom, bottom]),
        4: ([left, left], [bottom, top]),
        5: ([left, back_left], [top, back_top]),
        6: ([right, right_pushed_back], [top, back_top]),
        7: ([right, right_pushed_back], [bottom, back_bottom]),
        8: ([left, back_left], [bottom, back_bottom]),
        9: ([back_left, back_right], [back_top, back_top]),
        10: ([back_right, back_right], [back_bottom, back_top]),
        11: ([back_left, back_right], [back_bottom, back_bottom]),
        12: ([back_left, back_left], [back_bottom, back_top]),
    }

    # Dict insertion order is 1..12, which fixes the drawing order.
    for number, (xs, ys) in segments.items():
        if number in edges:
            ax.plot(xs, ys, **kwargs)

    if label:
        ax.text(x + 0.5 * size, y + 0.5 * size, label, ha='center', va='center', **label_kwargs)

In [None]:
# Figure: the two operands of ``np.array([16, 9, 80]) + 8`` side by side --
# a row of three cubes for the vector and a lone cube for the scalar.
fig = plt.figure(figsize=(10, 2), facecolor='w')
ax = fig.add_axes([0, 0, 1, 1], xticks=[], yticks=[], frameon=False)

# Vector operand. Each cube after the first omits its left face and top-left
# connector (edges 4 and 5), which the cube to its left already drew.
for cube_xy, cube_edges, cube_label in (
    ((1, 1), [1, 2, 3, 4, 5, 6, 9], '16'),
    ((2, 1), [1, 2, 3, 6, 9], '9'),
    ((3, 1), [1, 2, 3, 6, 7, 9, 10], '80'),
):
    draw_cube(ax=ax, xy=cube_xy, size=1, depth=depth,
              edges=cube_edges, label=cube_label, **solid)

ax.text(4.6, 1.6, '+', size=16, ha='center', va='center')

# Scalar operand: a single stand-alone cube.
draw_cube(ax=ax, xy=(5, 1), size=1, depth=depth,
          edges=[1, 2, 3, 4, 5, 6, 7, 9, 10], label='8', **solid)

# Caption showing the NumPy expression being illustrated.
ax.text(1.5, 0.2, r'${\tt np.array([16, 9, 80]) + 8}$',
        size=14, ha='left', va='bottom')

# Data limits chosen so one data unit has the same length on both axes
# for the 10x2-inch figure (15/10 == 3/2), keeping the cubes square.
ax.set_xlim(0, 15)
ax.set_ylim(0, 3)

fig.savefig(os.path.join(os.pardir, "COMP_002_000_broadcasting_vector_scalar.png"))
plt.show()

In [None]:
# Figure: ``np.array([16, 9, 80]) + 8`` with broadcasting made visible -- the
# scalar 8 is followed by two faint copies so it matches the vector's length,
# and the element-wise sums 24, 17, 88 appear after the equals sign.
fig = plt.figure(figsize=(10, 2), facecolor='w')
ax = fig.add_axes([0, 0, 1, 1], xticks=[], yticks=[], frameon=False)

# Left operand: the three-element vector.
for cube_xy, cube_edges, cube_label in (
    ((1, 1), [1, 2, 3, 4, 5, 6, 9], '16'),
    ((2, 1), [1, 2, 3, 6, 9], '9'),
    ((3, 1), [1, 2, 3, 6, 7, 9, 10], '80'),
):
    draw_cube(ax=ax, xy=cube_xy, size=1, depth=depth,
              edges=cube_edges, label=cube_label, **solid)

ax.text(4.6, 1.6, '+', size=16, ha='center', va='center')

# Right operand: the real scalar cube, then two translucent copies drawn
# with the ``dotted`` style to mark the values that are only implied.
draw_cube(ax=ax, xy=(5, 1), size=1, depth=depth,
          edges=[1, 2, 3, 4, 5, 6, 7, 9, 10], label='8', **solid)
for copy_x in (6, 7):
    draw_cube(ax=ax, xy=(copy_x, 1), size=1, depth=depth,
              edges=[1, 2, 3, 6, 7, 9, 10, 11], label='8', **dotted)

ax.text(8.7, 1.6, '=', size=16, ha='center', va='center')

# Result vector, drawn with the same edge pattern as the left operand.
for cube_xy, cube_edges, cube_label in (
    ((9, 1), [1, 2, 3, 4, 5, 6, 9], '24'),
    ((10, 1), [1, 2, 3, 6, 9], '17'),
    ((11, 1), [1, 2, 3, 6, 7, 9, 10], '88'),
):
    draw_cube(ax=ax, xy=cube_xy, size=1, depth=depth,
              edges=cube_edges, label=cube_label, **solid)

# Caption showing the NumPy expression being illustrated.
ax.text(1.5, 0.2, r'${\tt np.array([16, 9, 80]) + 8}$',
        size=14, ha='left', va='bottom')

# Same equal-scale limits as the operands-only figure.
ax.set_xlim(0, 15)
ax.set_ylim(0, 3)

fig.savefig(os.path.join(os.pardir, "COMP_002_001_broadcasting_vector_scalar.png"))
plt.show()