In [2]:
from manim import *
from manim.utils.color import Colors
# from manim_fontawesome import *
from manim_fonts import *
from random import seed, shuffle
import numpy as np
import scipy as sp

%load_ext autoreload
%autoreload 2

from nextgen_01_defs import *
from nextgen_01_lib import *


# 00: Welcome

In [2]:
%%manim -v WARNING --progress_bar None NG_01_00_Welcome

seed(14)


class NG_01_00_Welcome(Scene):
    """Title card: the CSIRO Data61 logo, the lecture title, then the presenter."""

    def construct(self):
        with RegisterFont("Montserrat") as fonts:
            title_font = fonts[0]

            # Draw the logo large, then shrink it into the bottom-right corner,
            # where it stays for the rest of the scene.
            logo = SVGMobject('./CSIRO-Data61-logo.svg').scale(1.5)
            # logo = SVGMobject('./Data61-logo.svg')
            # logo = SVGMobject('./CSIRO-logo.svg')
            self.play(Write(logo))
            self.wait()
            logo.generate_target()
            logo.target.scale(0.30).to_edge(DR)
            self.play(MoveToTarget(logo))
            self.wait()

            lesson_text = Text('NextGen Lecture 01', font=title_font)
            self.play(Write(lesson_text))
            self.wait()
            lesson_text2 = Text('Beginning probability', font=title_font)
            # ReplacementTransform swaps lesson_text out for lesson_text2 in the
            # scene, so the FadeOut below removes the text actually on screen.
            # A plain Transform would leave lesson_text (now showing the
            # subtitle) visible, and FadeOut(lesson_text2) would only fade a
            # copy that was never added.
            self.play(ReplacementTransform(lesson_text, lesson_text2))
            self.wait()
            self.play(FadeOut(lesson_text2))

            dan = ImageMobject('dan_head.webp')
            # NOTE(review): align_to(dan, DOWN) lines the caption's bottom edge
            # up with the photo's bottom edge, so the two overlap; use
            # next_to(dan, DOWN) if the caption should sit below the photo.
            lesson_text3 = Text('with Dan MacKinlay', font=title_font).align_to(dan, DOWN)
            self.play(FadeIn(dan))
            self.play(Write(lesson_text3))
            self.wait(2)


# 01: It takes a village of 100 people

In [6]:
%%manim -v WARNING --progress_bar None NG_01_01_ProbDoctor
SEED = 15


class NG_01_01_ProbDoctor(Scene):
    """Counting walkthrough of P(sick | +ve) in a village of 100 people.

    A shuffled grid of patient icons sits beside a census table. The
    positive-testing patients are recoloured, the negatives are removed, and
    the remaining positives are bracketed to read off the answer as a
    fraction. The scene ends by bringing the whole population back.
    """

    def construct(self):
        self.n_well_pos = 9
        self.n_well_neg = 81
        self.n_sick_pos = 9
        self.n_sick_neg = 1
        self.n_pop = (
            self.n_well_pos +
            self.n_well_neg +
            self.n_sick_pos +
            self.n_sick_neg
        )
        n_sick = self.n_sick_pos + self.n_sick_neg
        n_well = self.n_well_pos + self.n_well_neg
        n_pos = self.n_well_pos + self.n_sick_pos
        # Seeds the stdlib RNG used by shuffle() below.
        seed(SEED)

        wait_scale = 0.01

        def make_icons(n):
            # Every icon starts with the "well" marker in the negative colour;
            # the positives are recoloured later with FadeToColor.
            return [patient_state_icon(WELL_SYMB, NEG_COLOR) for _ in range(n)]

        # NOTE(review): sick patients are also drawn with WELL_SYMB, presumably
        # so sickness is only revealed by Indicate; confirm this is intended.
        well_pos = make_icons(self.n_well_pos)
        well_neg = make_icons(self.n_well_neg)
        sick_pos = make_icons(self.n_sick_pos)
        sick_neg = make_icons(self.n_sick_neg)
        pop = well_pos + well_neg + sick_pos + sick_neg
        shuffle(pop)
        pop_group = Group(*pop)
        pop_group.arrange_in_grid(n_cols=20)
        pop_group.shift(LEFT*3.25)

        # The table is built from the counts above so it stays in sync with them.
        census_text = MathTex(
            r"\text{\# pop.}&=%d\\" % self.n_pop,
            r"\text{\# sick}&=%d\\" % n_sick,
            r"\text{\# well}&=%d\\" % n_well,
            r"\text{\# well \& +ve}&=%d\\" % self.n_well_pos,
            r"\text{\# sick \& +ve}&=%d\\" % self.n_sick_pos,
            # {{ }} splits the "?" into its own submobject so it can be replaced.
            r"""\frac{\text{\# sick \& +ve}}{\text{\# +ve}} &= {{ \,?\, }}""",
            font_size=55)

        (
            pop_text,
            sick_text,
            well_text,
            well_pos_text,
            sick_pos_text,
            true_sick_text,
            true_sick_ans_text
        ) = census_text
        census_text.shift(RIGHT*3.25)
        # self.add(index_labels(census_text))

        well_group = VGroup(*well_pos, *well_neg)
        sick_group = VGroup(*sick_pos, *sick_neg)
        well_pos_group = VGroup(*well_pos)
        sick_pos_group = VGroup(*sick_pos)
        pos_group = VGroup(*well_pos, *sick_pos)
        neg_group = VGroup(*sick_neg, *well_neg)

        # Introduce the census one line at a time, highlighting each group.
        self.play(FadeIn(pop_group))
        self.wait(1.0*wait_scale)
        self.play(Write(pop_text))
        self.play(Indicate(pop_group))
        self.play(Write(sick_text))
        self.play(Indicate(sick_group))
        self.wait(1.0*wait_scale)
        self.play(Write(well_text))
        self.play(Indicate(well_group))
        self.wait(1.0*wait_scale)
        self.play(FadeToColor(pos_group, color=POS_COLOR))
        self.wait(1.0*wait_scale)
        self.play(Write(well_pos_text))
        self.play(Indicate(well_pos_group, color=POS_COLOR))
        self.wait(1.0*wait_scale)
        self.play(Write(sick_pos_text))
        # NOTE(review): the well positives above flash in POS_COLOR but these
        # flash in NEG_COLOR; confirm the contrast is intended.
        self.play(Indicate(sick_pos_group, color=NEG_COLOR))
        self.wait(1.0*wait_scale)
        self.play(Write(VGroup(true_sick_text, true_sick_ans_text)))
        self.wait(1.0*wait_scale)

        # Snapshot of the full, recoloured population, shown again at the end.
        whole_pop_group = pop_group.copy()

        # Keep only the positives, line them up and bracket them.
        self.play(Unwrite(neg_group))
        self.wait(1.0*wait_scale)

        pos_group.generate_target()
        pos_group.target.arrange(LEFT).shift(LEFT*3).scale(0.5)
        self.play(MoveToTarget(pos_group))
        self.wait(1.0*wait_scale)
        pos_brace = Brace(pos_group)
        pos_brace_text = MathTex(
            r"\text{\# +ve}&=%d" % n_pos).next_to(pos_brace, DOWN)
        self.play(Write(VGroup(pos_brace, pos_brace_text)))
        self.wait(1.0*wait_scale)
        sick_pos_brace = Brace(sick_pos_group, UP)
        sick_pos_brace_text = MathTex(
            r"\text{\# sick \& +ve}&=%d" % self.n_sick_pos).next_to(sick_pos_brace, UP)
        self.play(Write(VGroup(sick_pos_brace, sick_pos_brace_text)))

        # Replace the "?" with the raw fraction, then with its reduced form.
        answer_text = MathTex(
            r"\frac{%d}{%d}" % (self.n_sick_pos, n_pos),
            font_size=55
        ).move_to(true_sick_ans_text.get_left(), LEFT)
        self.play(Transform(true_sick_ans_text, answer_text))
        self.wait(1.0*wait_scale)
        # gcd(0, 0) == 0, so fall back to 1 to avoid dividing by zero.
        divisor = int(np.gcd(self.n_sick_pos, n_pos)) or 1
        better_answer_text = MathTex(
            r"\frac{%d}{%d}" % (self.n_sick_pos // divisor, n_pos // divisor),
            font_size=55
        ).move_to(true_sick_ans_text.get_left(), LEFT)
        self.wait(1.0*wait_scale)
        self.play(Transform(true_sick_ans_text, better_answer_text))
        self.wait(5.0*wait_scale)

        self.play(*[FadeOut(mob) for mob in self.mobjects])
        self.play(FadeIn(whole_pop_group))


# 02: A model of the doctor office

In [7]:
%%manim -v WARNING --progress_bar None NG_01_02_ProbViaSim

SEED = 15


class NG_01_02_ProbViaSim(Scene):
    """Patients walk through the doctor's door one at a time.

    Each visitor is randomly well or sick and tests +ve or -ve. A bar chart
    tracks the running proportion of visitors in each of the four
    (health, test-result) categories.
    """

    def construct(self):
        def create_person(marker=WELL_SYMB, color=NEG_COLOR):
            # A silhouette filled with the test-result colour, with the
            # health-status marker drawn on top of it.
            silhouette = SVGMobject(
                "noun-person-1492700.svg",
                fill_color=color,
                height=3,
                width=1.5,
            ).set_z_index(0)
            badge = patient_state_icon(marker, BLACK).scale(0.75).set_z_index(0.5)
            return VGroup(silhouette, badge)

        def classify(draw):
            # draw = (test_u, health_u). health_u < 0.9 means well, and a well
            # patient tests +ve when test_u > 0.9; otherwise the patient is
            # sick and tests +ve when test_u > 0.1.
            # Returns (marker, colour, bar index).
            test_u, health_u = draw[0], draw[1]
            if health_u < 0.9:
                if test_u > 0.9:
                    return WELL_SYMB, POS_COLOR, 1
                return WELL_SYMB, NEG_COLOR, 0
            if test_u > 0.1:
                return SICK_SYMB, POS_COLOR, 3
            return SICK_SYMB, NEG_COLOR, 2

        seed(SEED)
        rng = np.random.default_rng(SEED)

        DOOR_OFFSET = LEFT * 4

        # An opaque black block left of the door hides people before they enter.
        corridor = Rectangle(
            height=4, width=4,
            color=BLACK,
            fill_color=BLACK,
            fill_opacity=1.0
        ).shift(DOOR_OFFSET+LEFT*3).set_z_index(1)
        door = Rectangle(height=4, width=2).set_z_index(3).shift(DOOR_OFFSET)

        self.add(corridor)
        self.play(FadeIn(door))

        chart = BarChart(
            values=[0, 0, 0, 0],
            bar_names=[WELL_SYMB, WELL_SYMB, SICK_SYMB, SICK_SYMB],
            bar_colors=[NEG_COLOR, POS_COLOR, NEG_COLOR, POS_COLOR],
            y_range=[0, 1.0, 0.1],
            y_length=6,
            x_length=6,
        )
        chart.shift(3*RIGHT)
        self.play(FadeIn(chart))

        # The first three visitors are scripted: a well +ve, a sick +ve and a
        # sick -ve. All later visitors are random.
        scripted = [
            np.array([0.95, 0.8]),
            np.array([0.94, 0.95]),
            np.array([0.04, 0.95]),
        ]
        tally = np.array([0, 0, 0, 0])
        for visitor in range(30):
            draw = scripted[visitor] if visitor < len(scripted) else rng.random(2)
            marker, colour, bar = classify(draw)
            tally[bar] += 1

            person = create_person(marker, colour).shift(DOOR_OFFSET + LEFT*1.5)
            start = person.get_center()
            self.add(person)
            self.play(Succession(
                MoveAlongPath(person, Line(start, start + RIGHT*3.5)),
                chart.animate.change_bar_values(tally/tally.sum()),
                FadeOut(person),
                run_time=3,
            ))

        self.play(FadeOut(chart), FadeOut(door))


# 03: Probability via simulation

In [8]:
%%manim -v WARNING --progress_bar None NG_01_03_ProbViaSim

SEED = 15


class NG_01_03_ProbViaSim(Scene):
    """Monte Carlo estimate of the four (health, test-result) probabilities.

    The unit square is split into four rectangles whose areas are the category
    probabilities. Uniform random points are dropped into the square, and a
    bar chart tracks the running fraction landing in each rectangle.
    """

    # Stop after this many sampled points; None runs the full schedule.
    # The small default keeps previews quick.
    max_points = 13

    def construct(self):
        seed(SEED)
        rng = np.random.default_rng(SEED)

        targetspace = Axes(
            y_range=[0, 1.0, 0.1],
            x_range=[0, 1.0, 0.1],
            y_length=4,
            x_length=4,
            # x_axis_config={"font_size": 36},
        )
        targetspace.shift(4*LEFT)
        self.add(targetspace)

        def as_rectangle_corners(bottom_left, top_right):
            """Corners of an axis-aligned rectangle, counter-clockwise from top-right."""
            return [
                (top_right[0], top_right[1]),
                (bottom_left[0], top_right[1]),
                (bottom_left[0], bottom_left[1]),
                (top_right[0], bottom_left[1]),
            ]

        def get_rectangle(bottom_left, top_right, color=POS_COLOR):
            """Translucent filled rectangle on `targetspace`, given in axis coordinates."""
            polygon = Polygon(
                *[
                    targetspace.c2p(*corner)
                    for corner in as_rectangle_corners(bottom_left, top_right)
                ]
            )
            polygon.stroke_width = 1
            polygon.set_fill(color, opacity=0.3)
            polygon.set_stroke(color)
            return polygon

        def in_zone(zone, point):
            """Half-open containment test: [x0, x1) x [y0, y1)."""
            return (
                zone[0][0] <= point[0] < zone[1][0]
                and zone[0][1] <= point[1] < zone[1][1]
            )

        icon_scale = 0.6
        well_neg_zone = (0.0, 0.0), (0.9, 0.9)
        well_pos_zone = (0.9, 0.0), (1.0, 0.9)
        sick_neg_zone = (0.0, 0.9), (0.1, 1.0)
        sick_pos_zone = (0.1, 0.9), (1.0, 1.0)
        # The order here is also the bar order of the chart below.
        zone_specs = [
            (well_neg_zone, WELL_SYMB, NEG_COLOR),
            (well_pos_zone, WELL_SYMB, POS_COLOR),
            (sick_neg_zone, SICK_SYMB, NEG_COLOR),
            (sick_pos_zone, SICK_SYMB, POS_COLOR),
        ]
        for bounds, symb, color in zone_specs:
            rect = get_rectangle(*bounds, color)
            self.add(rect)
            self.add(patient_state_icon(symb, color=color).scale(icon_scale).shift(rect.get_center()))
        zones = [bounds for bounds, _, _ in zone_specs]

        def zone_index(point):
            """Index into `zones` of the zone containing `point`."""
            for i, bounds in enumerate(zones):
                if in_zone(bounds, point):
                    return i
            raise ValueError("Point not in any zone")

        chart = BarChart(
            values=[0, 0, 0, 0],
            bar_names=[WELL_SYMB, WELL_SYMB, SICK_SYMB, SICK_SYMB],
            bar_colors=[NEG_COLOR, POS_COLOR, NEG_COLOR, POS_COLOR],
            y_range=[0, 1.0, 0.1],
            y_length=6,
            x_length=6,
            # x_axis_config={"font_size": 36},
        )
        chart.shift(3*RIGHT)
        self.play(FadeIn(chart))
        # Animated per-bar value labels (get_bar_labels + updaters) were tried
        # but did not work, so they are left out.
        self.wait(1)

        # One run_time per point: 10 slow points, then batches of 2k points at
        # 1/k seconds each, for k = 1..24.
        schedule = [(1, 1.0)] * 10 + [(2*k, 1/k) for k in range(1, 25)]
        durations = [dur for n_points, dur in schedule for _ in range(n_points)]
        # Cap the whole run rather than one batch, so no points are sampled and
        # counted without being animated.
        if self.max_points is not None:
            durations = durations[:self.max_points]

        # The first few points are scripted so the bars start out even-ish.
        scripted = [
            np.array([0.45, 0.95]),
            np.array([0.12, 0.8]),
            np.array([0.03, 0.97]),
            np.array([0.93, 0.2]),
            np.array([0.03, 0.97]),
        ]
        count = np.array([0, 0, 0, 0])
        for n_drawn, dur in enumerate(durations):
            p = scripted[n_drawn] if n_drawn < len(scripted) else rng.random(2)
            # Point() renders nothing; a Dot makes the sample visible.
            self.add(Dot(targetspace.c2p(*p)))
            count[zone_index(p)] += 1
            self.play(
                chart.animate.change_bar_values(count/count.sum()),
                run_time=dur)




# 04: Continuous probability

In [3]:
%%manim -v WARNING --progress_bar None NG_01_04_ContinuousModels
SEED = 15


class NG_01_04_ContinuousModels(Scene):
    """People with random (width, height) walk in and are plotted as points.

    Body sizes are drawn from a correlated 2-D Gaussian. The first N_ANIMATED
    people walk onto the axes with guide lines to their point; the rest just
    appear as dots.
    """

    def construct(self):
        seed(SEED)
        rng = np.random.default_rng(SEED)

        N_POP = 100
        # Number of people shown walking in before switching to quick dots.
        N_ANIMATED = 20

        def create_person(width=1.5, height=2.6, *_, color=PURPLE):
            person = SVGMobject(
                "noun-person-1492700.svg",
                fill_color=color,
                width=width, height=height,
            ).set_z_index(0)
            # SVGMobject ignores the height here, so stretch to the exact size.
            person = person.stretch_to_fit_width(width).stretch_to_fit_height(height)
            return person

        targetspace = Axes(
            y_range=[0, 6, 1.0],
            x_range=[0, 5, 1.0],
            y_length=6,
            x_length=5,
            # x_axis_config={"font_size": 36},
        )

        targetspace.shift(3*RIGHT)
        self.play(FadeIn(targetspace))
        DOOR_OFFSET = LEFT * 4

        # An opaque black block left of the door hides people before they enter.
        corridor = Rectangle(
            height=6, width=4,
            color=BLACK,
            fill_color=BLACK,
            fill_opacity=1.0
        ).shift(DOOR_OFFSET+LEFT*3.5).set_z_index(1)

        door = Rectangle(height=6, width=3).set_z_index(3).shift(DOOR_OFFSET)

        self.add(corridor)
        self.play(FadeIn(door))

        # Correlated Gaussian over (width, height): sds 1 and 1.5 with
        # correlation 0.8, with the covariance then scaled by 0.25.
        mean = np.array([2, 3.5]).reshape(-1, 1)
        covar = np.array([1**2, 1*1.5*0.8, 1*1.5*0.8, 1.5**2]).reshape(2, 2) * 0.25
        chol = np.linalg.cholesky(covar)
        # Each column is one sample: mean + L @ z with z ~ N(0, I).
        samples = mean + chol @ rng.normal(size=(2, N_POP))

        for i, (width, height) in enumerate(samples.T):
            corner = targetspace.c2p(width, height)
            lines = targetspace.get_lines_to_point(corner)
            point = Dot(corner)

            if i < N_ANIMATED:
                person = create_person(
                    width, height
                ).shift(DOOR_OFFSET + LEFT*3)
                self.add(person)
                # Walk so the person's box spans (0, 0) to (width, height).
                self.play(
                    MoveAlongPath(
                        person,
                        Line(person.get_center(), targetspace.c2p(width/2, height/2))
                    )
                )
                self.play(Write(lines), run_time=0.5)
                self.play(Indicate(point))
                self.play(
                    FadeOut(person),
                    FadeOut(lines),
                    run_time=0.5
                )
            else:
                self.play(FadeIn(point), run_time=0.1)
                self.play(Indicate(point), run_time=0.1)



KeyboardInterrupt: 

In [4]:
%%manim -v WARNING --progress_bar None NG_01_05_ContinuousModels
SEED = 15


class NG_01_05_ContinuousModels(Scene):
    """Probability that a shirt fits, as the fraction of sampled people inside it.

    N_POP (width, height) samples are plotted. A resizable shirt-shaped box
    marks the sizes the shirt fits; points inside it turn green, and a live
    P(shirt fits | shirt) fraction is displayed. Finally the underlying
    Gaussian density is overlaid as a translucent image.
    """

    def construct(self):
        seed(SEED)
        rng = np.random.default_rng(SEED)

        N_POP = 100
        FONT_SIZE = 45
        # Opacity of the density overlay where the density peaks.
        DENSITY_OPACITY = 0.5

        targetspace = Axes(
            y_range=[0, 6, 1.0],
            x_range=[0, 5, 1.0],
            y_length=6,
            x_length=5,
            # x_axis_config={"font_size": 36},
        )
        targetspace.shift(3*RIGHT).set_z_index(4)

        mean = np.array([2, 3.5]).reshape(-1, 1)
        covar = np.array([1**2, 1*1.5*0.8, 1*1.5*0.8, 1.5**2]).reshape(2, 2) * 0.25
        chol = np.linalg.cholesky(covar)
        # Each column is one (width, height) sample.
        samples = mean + chol @ rng.normal(size=(2, N_POP))

        # Initial extent of the shirt box, in axis coordinates. The axes use
        # one scene unit per data unit, so these double as scene-unit sizes.
        t_bottom = 0
        t_left = 0.0
        t_right = 3
        t_top = 4
        t_width = t_right - t_left
        t_height = t_top - t_bottom

        t_shirt = SVGMobject(
            "noun-shirt-3972412.svg",
            fill_color=GREEN,
            fill_opacity=0.5,
            width=t_width, height=t_height,
        ).set_z_index(-1).stretch_to_fit_width(t_width).stretch_to_fit_height(t_height)
        t_shirt_box = VGroup(
            t_shirt,
            Rectangle(
                height=t_height, width=t_width,
                stroke_color=GREEN
            )
        ).move_to(targetspace.c2p(t_width/2 + t_left, t_height/2 + t_bottom))

        def in_t_shirt(obj):
            """True if obj's centre lies strictly inside the shirt box's bounding box."""
            x, y, *_ = obj.get_center()
            return (
                t_shirt_box.get_left()[0] < x < t_shirt_box.get_right()[0]
                and t_shirt_box.get_bottom()[1] < y < t_shirt_box.get_top()[1]
            )

        def update_t_shirt(new_top, new_right, new_left=0, new_bottom=0, run_time=1):
            """Animate the shirt box to span [new_left, new_right] x [new_bottom, new_top]."""
            new_width = new_right - new_left
            new_height = new_top - new_bottom
            t_shirt_box.generate_target()
            t_shirt_box.target.stretch_to_fit_width(new_width).stretch_to_fit_height(new_height)
            t_shirt_box.target.move_to(
                targetspace.c2p(new_width/2 + new_left, new_height/2 + new_bottom))
            self.play(MoveToTarget(t_shirt_box), run_time=run_time)

        def color_in_t_shirt(obj):
            """Updater: green inside the shirt box, white outside."""
            obj.set_color(GREEN if in_t_shirt(obj) else WHITE)

        self.play(FadeIn(targetspace))
        point_list = []
        for width, height in samples.T:
            point = Dot(targetspace.c2p(width, height)).set_z_index(1)
            color_in_t_shirt(point)
            point.add_updater(color_in_t_shirt)
            point_list.append(point)
        point_group = VGroup(*point_list)
        self.play(FadeIn(point_group, t_shirt_box))
        self.wait()

        t_shirt_prob_text = MathTex(
            r"\mathbb{P}\left( \text{Shirt fits}| {{ WWW }} \right) = {{ \frac{0 }{100} }}",
            font_size=FONT_SIZE,)
        t_shirt_prob_text.shift(LEFT*3.25 + UP*2)
        p_l, p_cond, p_r, p_frac = t_shirt_prob_text
        # "WWW" only reserves space; a miniature shirt is shown in its place.
        mini_t_shirt = t_shirt.copy().match_width(p_cond).stretch_to_fit_height(1.8).move_to(p_cond)

        def refresh_fraction(m):
            """Updater: show (#points inside the shirt box) / N_POP."""
            n_inside = sum(in_t_shirt(p) for p in point_list)
            m.become(
                MathTex(
                    r"\frac{ %d }{%d}" % (n_inside, N_POP), font_size=FONT_SIZE
                ).move_to(p_frac)
            )

        p_frac.add_updater(refresh_fraction)
        self.play(FadeIn(p_l, mini_t_shirt, p_r, p_frac))

        self.wait()

        update_t_shirt(4, 3, run_time=2)
        self.wait()

        update_t_shirt(1, 5, run_time=2)
        self.wait()

        update_t_shirt(5, 1, run_time=2)
        self.wait()

        update_t_shirt(4, 5, run_time=2)
        self.wait(5)

        # Rasterise the Gaussian density onto a 100x100 grid covering the axes
        # and show it as an image.
        rv = sp.stats.multivariate_normal(mean=mean.ravel(), cov=covar)
        grid = np.dstack(np.meshgrid(np.linspace(0, 5, 100), np.linspace(0, 6, 100)))
        density = rv.pdf(grid)
        # Normalise to 0..255, boosted by 1.25 so the peak saturates.
        density = density / density.max() * (255 * 1.25)
        # Image row 0 is drawn at the top, while y increases upwards.
        density = np.flip(density, 0)
        density = np.minimum(density, 255).astype(np.uint8)
        # Grey RGBA image whose alpha follows the density, so low-density areas
        # stay transparent. ImageMobject.set_opacity() overwrites the whole
        # alpha channel with a constant, so the overall opacity is applied to
        # the alpha channel here instead.
        rgba = np.repeat(np.expand_dims(density, -1), 4, axis=2)
        rgba[..., 3] = (rgba[..., 3] * DENSITY_OPACITY).astype(np.uint8)
        densityim = ImageMobject(rgba).shift(3*RIGHT)
        densityim.set_resampling_algorithm(RESAMPLING_ALGORITHMS["cubic"])
        densityim.stretch_to_fit_height(6).stretch_to_fit_width(5).set_z_index(1).align_to(targetspace.c2p(0, 0), DL)
        self.play(FadeIn(densityim))
        self.wait()



In [23]:
# %%manim -v WARNING --progress_bar None NG_01_06_ContinuousModels
# SEED = 15


# # The 3D version turned out to be hard to get right and slow to render, so this is not used

# class NG_01_06_ContinuousModels(ThreeDScene):
#     def construct(self):
#         mean = np.array([2, 3.5]).reshape(-1,1)
#         covar = np.array([1**2, 1*1.5*0.8, 1*1.5*0.8, 1.5**2]).reshape(2,2) * 0.25
#         rv = sp.stats.multivariate_normal(
#             mean=mean.ravel(), cov=covar)
#         z_height = rv.pdf(mean.ravel())

#         seed(SEED)
#         rng = np.random.default_rng(SEED)

#         N_POP = 100
#         FONT_SIZE = 45
#         resolution_fa = 32

#         self.set_camera_orientation(
#             phi=180 * DEGREES,
#             theta=0 * DEGREES,
#             zoom=0.3,)
#         axes = ThreeDAxes(
#             x_range=(0, 5, 1),
#             y_range=(0, 6, 1),
#             z_range=(0, z_height, 0.5),
#             x_length=10.5, y_length=10.5, z_length=z_height
#         )


#         def param_gauss(u, v):
#             x = u
#             y = v
#             return(rv.pdf([x, y]))
# # 
#         gauss_plane = Surface(
#             lambda u, v: axes.c2p(u, v, param_gauss(u, v)),
#             resolution=(resolution_fa, resolution_fa),
#             v_range=[0, 5],
#             u_range=[0, 6]
#         )

#         gauss_plane.scale(2, about_point=ORIGIN)
#         gauss_plane.set_style(fill_opacity=1,stroke_color=GREEN)
#         gauss_plane.set_fill_by_value(
#             axes=axes, colorscale=[(BLACK, 0.0), (WHITE, z_height)],
#             axis=2)
#         self.add(axes, gauss_plane)


#         chol = np.linalg.cholesky(covar)
#         vals = mean + chol @ rng.normal(size=(2,100))

#         self.camera.should_apply_shading = False

# #         t_shirt = SVGMobject(
# #                 "noun-shirt-3972412.svg",
# #                 fill_color=GREEN,
# #                 fill_opacity=0.5,
# #                 width=t_right, height=t_top,
# #             ).set_z_index(-1).stretch_to_fit_width(t_right).stretch_to_fit_height(t_top)
# #         t_shirt_box = VGroup(
# #             t_shirt,
# #             Rectangle(
# #                 height=t_top, width=t_right,
# #                 stroke_color=GREEN
# #             )
# #         ).move_to(targetspace.c2p(t_right/2+t_left, t_top/2+t_bottom))

# #         def in_t_shirt(obj):
# #             # x, y = targetspace.p2c(obj.get_center())
# #             # isin =  t_left < x < t_right and t_bottom < y < t_top
# #             # return isin
# #             x, y, *_ = obj.get_center()
# #             isin =  t_shirt_box.get_left()[0] < x < t_shirt_box.get_right()[0] and t_shirt_box.get_bottom()[1]  < y < t_shirt_box.get_top()[1]
# #             return isin

# #         def update_t_shirt(new_top, new_right, new_left=0, new_bottom=0, run_time = 1):
# #             new_width = new_right - new_left
# #             new_height = new_top - new_bottom
# #             t_shirt_box.generate_target()
# #             t_shirt_box.target.stretch_to_fit_width(new_width).stretch_to_fit_height(new_height)
# #             t_shirt_box.target.move_to(targetspace.c2p(new_width/2+t_left, new_height/2+t_bottom))
# #             self.play(MoveToTarget(t_shirt_box), run_time=run_time)

# #         def color_in_t_shirt(obj):
# #             if in_t_shirt(obj):
# #                 obj.set_color(GREEN)
# #             else:
# #                 obj.set_color(WHITE)
        
# #         self.play(
# #             FadeIn(targetspace),
# #         )
# #         point_list = []
# #         for i, (vals) in enumerate(vals.T):
# #             person_shape = np.zeros(3)
# #             person_shape[:2] = vals
# #             width, height = person_shape[:2]
# #             # # should be the same
# #             axis_person_shape = targetspace.c2p(width, height) - targetspace.c2p(0, 0)

# #             lines = targetspace.get_lines_to_point(targetspace.c2p(width, height))
# #             point = Dot(targetspace.c2p(width, height)).set_z_index(1)
# #             color_in_t_shirt(point)
# #             point.add_updater(color_in_t_shirt)
# #             point_list.append(point)
# #         point_group = VGroup(*point_list)
# #         self.play(FadeIn(point_group, t_shirt_box))
# #         self.wait()

# #         t_shirt_prob_text = MathTex(
# #             # r"\mathbb{P}\left( \text{Shirt fits} | {{ \quad }} \right) = \frac{ {{ 0 }} }{100}",
# #             r"\mathbb{P}\left( \text{Shirt fits}| {{ WWW }} \right) = {{ \frac{0 }{100} }}",
# #             font_size=FONT_SIZE,)
# #         t_shirt_prob_text.shift(LEFT*3.25 + UP *2)
# #         p_l, p_cond, p_r, p_frac = t_shirt_prob_text
# #         mini_t_shirt = t_shirt.copy().match_width(p_cond).stretch_to_fit_height(1.8).move_to(p_cond)
# #         # p_cond.become(mini_t_shirt)
# #         p_frac.add_updater(
# #             lambda m: m.become(
# #                 MathTex(
# #                     r"\frac{ %d }{100}" % sum(in_t_shirt(p) for p in point_list), font_size=FONT_SIZE
# #                 ).move_to(p_frac)
# #             )
# #         )
# #         self.play(FadeIn(p_l, mini_t_shirt, p_r, p_frac))

# #         self.wait()

# #         update_t_shirt(4, 3, run_time=2)

# #         self.wait()

# #         update_t_shirt(1, 5, run_time=2)
# #         self.wait()

# #         update_t_shirt(5, 1, run_time=2)
# #         self.wait()

# #         update_t_shirt(4, 5, run_time=2)
# #         self.wait(5)
# #         # t_shirt_box.set_z_index(1)
# #         ## This nightmare is how we get a density field in manim:
# #         rv = sp.stats.multivariate_normal(mean=mean.ravel(), cov=covar)
# #         pos = np.dstack(np.meshgrid(np.linspace(0, 5, 100), np.linspace(0, 6, 100)))
# #         density = rv.pdf(pos)
# #         density /= density.max()
# #         density *= 255 * 1.25
# #         density = np.flip(density, 0)
# #         density = np.minimum(density, 255)
# #         density = density.astype(np.uint8) 
# #         # Was hoping this would add an alpha channel, doesn't seem to though
# #         density = np.repeat(np.expand_dims(density, -1), 4, axis=2)
# #         densityim = ImageMobject(density ).shift(3*RIGHT)
# #         densityim.set_resampling_algorithm(RESAMPLING_ALGORITHMS["cubic"])
# #         densityim.stretch_to_fit_height(6).stretch_to_fit_width(5).set_z_index(1).align_to(targetspace.c2p(0,0), DL).set_opacity(0.5)
# #         # print(densityim.get_height(), densityim.get_width())
# #         # print(densityim.get_left()[0], densityim.get_right()[0])
# #         # print(densityim.get_bottom()[1], densityim.get_top()[1])
# #         self.play(FadeIn(densityim))
# #         self.wait()

# #         # t_shirt_pdf_text = MathTex(
# #         #     r"\mathbb{P}\left( \text{Shirt fits}| {{ WWW }} \right) = {{ \frac{0 }{100} }}",
# #         #     font_size=FONT_SIZE,)
# #         # t_shirt_prob_text.shift(LEFT*3.25 + UP *2)
# #         # p_l, p_cond, p_r, p_frac = t_shirt_prob_text
# #         # mini_t_shirt = t_shirt.copy().match_width(p_cond).stretch_to_fit_height(1.8).move_to(p_cond)
# #         # # p_cond.become(mini_t_shirt)
# #         # p_frac.add_updater(
# #         #     lambda m: m.become(
# #         #         MathTex(
# #         #             r"\frac{ %d }{100}" % sum(in_t_shirt(p) for p in point_list), font_size=FONT_SIZE
# #         #         ).move_to(p_frac)
# #         #     )
# #         # )
# #         # self.play(FadeIn(p_l, mini_t_shirt, p_r, p_frac))



In [30]:
%%manim -v WARNING --progress_bar None NG_01_07_ContinuousModelsBinned
SEED = 15

class NG_01_07_ContinuousModelsBinned(Scene):
    """Bin samples from a correlated 2-D Gaussian into grid cells of two sizes.

    Draws ``N_POP`` points on 0..X_MAX x 0..Y_MAX axes, overlays a grid of
    1 x 1 cells and labels every non-empty cell with ``count / N_POP``, then
    replaces it with a finer grid of 0.5 x 0.5 cells and labels that.
    """

    def construct(self):
        seed(SEED)
        rng = np.random.default_rng(SEED)

        N_POP = 100
        # Axis extents in data units; the binning grids are derived from these
        # so each grid always tiles exactly the visible axes.
        X_MAX, Y_MAX = 5, 6

        targetspace = Axes(
            y_range=[0, Y_MAX, 1.0],
            x_range=[0, X_MAX, 1.0],
            y_length=Y_MAX,
            x_length=X_MAX,
        )
        targetspace.shift(3*RIGHT).set_z_index(4)

        # Covariance built as [[sx^2, sx*sy*rho], [sx*sy*rho, sy^2]] with
        # sx=1, sy=1.5, rho=0.8, then scaled by 0.25.
        mean = np.array([2, 3.5]).reshape(-1, 1)
        covar = np.array([
            [1**2, 1*1.5*0.8],
            [1*1.5*0.8, 1.5**2],
        ]) * 0.25
        chol = np.linalg.cholesky(covar)
        # Draw exactly N_POP samples so the "count / N_POP" labels are true fractions.
        samples = mean + chol @ rng.normal(size=(2, N_POP))
        # Only needed by the disabled density overlay at the end of this method.
        rv = sp.stats.multivariate_normal(mean=mean.ravel(), cov=covar)

        # One dot per sample, placed in scene coordinates via the axes.
        point_list = [
            Dot(targetspace.c2p(x, y), fill_opacity=0.5, radius=0.05).set_z_index(1)
            for x, y in samples.T
        ]

        def create_box_on_axes(x, y, size, color=BLUE):
            """Return VGroup(placeholder_label, box) for a size x size cell.

            The cell's lower-left corner sits at axis point (x, y). Index 0 is a
            blank text placeholder and index 1 is the Rectangle; callers rely
            on that ordering.
            """
            w, h, *_ = targetspace.c2p(size, size) - targetspace.c2p(0, 0)
            box = Rectangle(width=w, height=h, color=color).set_z_index(2)
            placeholder = Text(r" ").set_z_index(3)
            labeled_box = VGroup(placeholder, box).set_z_index(5)
            labeled_box.move_to(targetspace.c2p(x, y), DL)
            return labeled_box

        def make_grid(cell_size):
            """Tile [0, X_MAX] x [0, Y_MAX] with square cells of side cell_size (axis units)."""
            n_x = round(X_MAX / cell_size)
            n_y = round(Y_MAX / cell_size)
            return VGroup(*[
                create_box_on_axes(i * cell_size, j * cell_size, cell_size)
                for i in range(n_x)
                for j in range(n_y)
            ])

        def in_box(obj, box):
            """True if obj's centre lies inside box.

            Half-open on the left/bottom edges so a point on an edge shared by
            two cells is counted in exactly one of them.
            """
            x, y, *_ = obj.get_center()
            return (
                box.get_left()[0] <= x < box.get_right()[0]
                and box.get_bottom()[1] <= y < box.get_top()[1]
            )

        def count_text_points_in_box(labelbox):
            """Build a label showing the fraction of points inside labelbox's rectangle.

            Empty cells get a blank Text so every cell still yields a mobject.
            """
            box = labelbox[1]
            count = sum(in_box(p, box) for p in point_list)
            if count == 0:
                text = Text(r" ").set_z_index(3)
            else:
                text = MathTex(
                    r"\mathbf{\frac{%d}{%d}}" % (count, N_POP)
                ).set_z_index(5).scale_to_fit_height(0.7 * box.get_height())
            text.move_to(box)
            text.set_color(box.get_color())
            return text

        def make_count_labels(grid):
            """One count label per cell of grid, in the same order."""
            return VGroup(*[count_text_points_in_box(cell) for cell in grid])

        self.play(FadeIn(targetspace))

        point_group = VGroup(*point_list)
        coarse_grid = make_grid(1)

        self.play(Write(point_group))
        self.wait()
        self.play(Write(coarse_grid))

        coarse_labels = make_count_labels(coarse_grid)
        self.play(Write(coarse_labels))

        self.wait()

        self.play(Unwrite(coarse_grid))
        self.play(Unwrite(coarse_labels))
        fine_grid = make_grid(0.5)
        self.wait()
        self.play(Write(fine_grid))

        self.play(Write(make_count_labels(fine_grid)))

        self.wait(3)

        # ## This nightmare is how we get a density field in manim:
        # pos = np.dstack(np.meshgrid(np.linspace(0, 5, 100), np.linspace(0, 6, 100)))
        # density = rv.pdf(pos)
        # density /= density.max()
        # density *= 255 * 1.25
        # density = np.flip(density, 0)
        # density = np.minimum(density, 255)
        # density = density.astype(np.uint8) 
        # # Was hoping this would add an alpha channel, doesn't seem to though
        # density = np.repeat(np.expand_dims(density, -1), 4, axis=2)
        # densityim = ImageMobject(density ).shift(3*RIGHT)
        # densityim.set_resampling_algorithm(RESAMPLING_ALGORITHMS["cubic"])
        # densityim.stretch_to_fit_height(6).stretch_to_fit_width(5).set_z_index(1).align_to(targetspace.c2p(0,0), DL).set_opacity(0.5)
        # # print(densityim.get_height(), densityim.get_width())
        # # print(densityim.get_left()[0], densityim.get_right()[0])
        # # print(densityim.get_bottom()[1], densityim.get_top()[1])
        # self.play(FadeIn(densityim))


