---
title: "Writing Stable Diffusion from Scratch 9"
reading-time: 
date: "2023-3-29"
categories: [fastaipart2,Stable-Diffusion]
---

All credit goes to fast.ai.
All mistakes are mine. 

After this blog post, you should know and have practiced the following: <br>
1- Callbacks and callable classes <br>
2- Partial <br>
3- Lambda <br>
4- `__dunder__` thingies <br>

In [51]:
import torch
import matplotlib.pyplot as plt
import random

So we've nearly got all of our infrastructure in place. Before we go further, there are some pieces of Python that not everybody knows, and some computer science concepts, that I want to talk about. That's what this foundations section is about. It covers some things in Python that you might not have come across before, or that may be a review for some of you. All of it is stuff we're going to be using in the next notebook.

That's why I want to cover it. We're going to be creating a `Learner` class: a very general-purpose training loop that we can get to do anything we want. And we're going to create things called callbacks to make that happen. So we'll spend a few moments on what callbacks are, how they're used in computer science and how they're implemented, and look at some examples. They come up a lot.

Perhaps the most common place you see callbacks in software is handling events from a graphical user interface. The main GUI library in Jupyter notebooks is called ipywidgets, and we can create a widget like a button. When we display it, it shows a button, but at the moment clicking it doesn't do anything. What we can do, though, is add an `on_click` callback to it: we pass it a function that is called when you click the button. So let's define that function. Then `w.on_click(f)` assigns the function `f` to the `on_click` callback. Now if I click the button, there you go, it works. What does that mean?

Well, a callback is simply a callable that you've provided. Remember, a callable is a more general version of a function. In this case, it is a function you've provided that will be called back when something happens, and here the thing that happens is a button click. So this is how we define and use a callback as a GUI event. Pretty much everything in ipywidgets works this way: if you want to create your own graphical user interfaces in Jupyter, you can do it with ipywidgets and these callbacks. These particular kinds of callbacks are called events, but they're just callbacks. All right, so that's using somebody else's callback.



## Callbacks

### Callbacks as GUI events

In [52]:
import ipywidgets as widgets

From the [ipywidget docs](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20Events.html):

- *the button widget is used to handle mouse clicks. The on_click method of the Button can be used to register function to be called when the button is clicked*

In [53]:
w = widgets.Button(description='Click me')

In [54]:
w

Button(description='Click me', style=ButtonStyle())

In [55]:
def f(o): print('hi')

In [56]:
w.on_click(f)

*NB: When callbacks are used in this way they are often called "events".*
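To see that there's no magic in `on_click`, here is a minimal sketch of how a widget might store and fire its click handlers. `FakeButton`, `click` and the handler list are hypothetical names for illustration; this is not how ipywidgets is implemented internally. The one detail borrowed from ipywidgets is that the callback receives the button itself as its argument (which is why `f` above takes a parameter `o`).

```python
class FakeButton:
    """Hypothetical stand-in for a GUI button (not the ipywidgets implementation)."""
    def __init__(self, description):
        self.description = description
        self._handlers = []              # callbacks registered via on_click

    def on_click(self, f):
        # Registering a callback just means remembering the function for later
        self._handlers.append(f)

    def click(self):
        # Simulate the GUI event: call back every registered function,
        # passing the button itself as the argument
        for h in self._handlers:
            h(self)

clicks = []
btn = FakeButton('Click me')
btn.on_click(lambda b: clicks.append(b.description))
btn.click()
btn.click()
print(clicks)  # ['Click me', 'Click me']
```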

Let's create our own callback. Say we have some very slow calculation: it takes a long time to add up the squares of the numbers 0 to 4, because we sleep for a second after each one. So let's run our slow calculation. Still running... come on, finish the calculation.

There we go. The answer is 30. Now, for a slow calculation like that (training a model is a slow calculation too), it would be nice to do things like print out the loss from time to time or show a progress bar. So for those kinds of things, we'd like to define a callback that is called at the end of each epoch or batch, or every few seconds, or something like that.



### Creating your own callback

In [57]:
from time import sleep

In [58]:
def slow_calculation():
    res = 0
    for i in range(5):
        res += i*i
        sleep(1)
    return res

In [59]:
slow_calculation()

30

So here's how we can modify our calculation routine so that you can optionally pass it a callback. All of this code is the same, except we've added one line that says: if there's a callback, call it and pass in the step we're up to. Then we can create our callback function. Just like we created the callback function `f` earlier, let's create a `show_progress` callback that tells us how far we've got. Now if we call `slow_calculation` passing in our callback, you can see it calls this function at the end of each step.

So here we've created our own callback, and there's nothing special about it. A callback doesn't require its own syntax, and it's not a new language feature. It's just an idea: passing in a function which some other function will call at particular times, such as at the end of a step or when you click a button. That's what we mean by callbacks.

We don't have to define the function ahead of time. We could define it at the same time we call `slow_calculation` by using `lambda`. As we've discussed before, `lambda` just defines a function without giving it a name. So here's a function that takes one parameter and prints out exactly the same thing as before, but written as a lambda.

We could make it more sophisticated: rather than always saying "Awesome! We've finished epoch...", we could let you pass in an exclamation and print that out. In this case, we can have our lambda call that function.

Another thing we can do is create a function that returns a function. So we could create a `make_show_progress` function where you pass in the exclamation. Inside it, we could use a lambda (there's no need to give the inner function a name) and just return it directly, or we can define a named inner function like `_inner`, as in the cell below. Either way, it returns a function that prints that exclamation. Here we pass in "Nice!", and that's exactly the same as what we've done before.

One of the reasons I wanted to show you that is that we can do exactly the same thing using `partial`. `partial(show_progress, "OK I guess")` does the same job as `make_show_progress`: it's again an example of a function returning a function. It returns a new function that calls `show_progress`, passing in "OK I guess" as the first parameter, and again it does exactly the same thing. We tend to use `partial` a lot, so that's certainly something worth spending time practicing.

Now, as we've discussed, Python doesn't care about types in particular, and there's nothing about any of this that requires `cb` to be a function.



In [60]:
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        res += i*i
        sleep(1)
        if cb: cb(i)
    return res

In [61]:
def show_progress(epoch): print(f"Awesome! We've finished epoch {epoch}!")

In [62]:
slow_calculation(show_progress)

Awesome! We've finished epoch 0!
Awesome! We've finished epoch 1!
Awesome! We've finished epoch 2!
Awesome! We've finished epoch 3!
Awesome! We've finished epoch 4!


30
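Because any callable works, a callback can also collect results instead of printing them. A minimal sketch, assuming a sleep-free copy of `slow_calculation` (renamed `fast_calculation` here, a hypothetical name) so that it runs instantly; the bound method `history.append` is passed directly as the callback.

```python
def fast_calculation(cb=None):
    # Same logic as slow_calculation above, minus the sleep(1)
    res = 0
    for i in range(5):
        res += i*i
        if cb: cb(i)
    return res

history = []
result = fast_calculation(history.append)  # a bound method used as the callback
print(result, history)  # 30 [0, 1, 2, 3, 4]
```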

### Lambdas and partials

In [63]:
slow_calculation(lambda o: print(f"Awesome! We've finished epoch {o}!"))

Awesome! We've finished epoch 0!
Awesome! We've finished epoch 1!
Awesome! We've finished epoch 2!
Awesome! We've finished epoch 3!
Awesome! We've finished epoch 4!


30

In [64]:
def show_progress(exclamation, epoch): print(f"{exclamation}! We've finished epoch {epoch}!")

In [65]:
slow_calculation(lambda o: show_progress("OK I guess", o))

OK I guess! We've finished epoch 0!
OK I guess! We've finished epoch 1!
OK I guess! We've finished epoch 2!
OK I guess! We've finished epoch 3!
OK I guess! We've finished epoch 4!


30

In [66]:
def make_show_progress(exclamation):
    def _inner(epoch): print(f"{exclamation}! We've finished epoch {epoch}!")
    return _inner

In [67]:
slow_calculation(make_show_progress("Nice!"))

Nice!! We've finished epoch 0!
Nice!! We've finished epoch 1!
Nice!! We've finished epoch 2!
Nice!! We've finished epoch 3!
Nice!! We've finished epoch 4!


30

In [68]:
from functools import partial

In [69]:
slow_calculation(partial(show_progress, "OK I guess"))

OK I guess! We've finished epoch 0!
OK I guess! We've finished epoch 1!
OK I guess! We've finished epoch 2!
OK I guess! We've finished epoch 3!
OK I guess! We've finished epoch 4!


30

In [70]:
f2 = partial(show_progress, "OK I guess")
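The lambda, the closure and the `partial` versions all build a one-argument callable with the exclamation baked in. A small sketch that checks this by having each version return its message instead of printing it (`message` and `make_message` are hypothetical helpers mirroring `show_progress` and `make_show_progress`), and that peeks at what a `functools.partial` object stores.

```python
from functools import partial

def message(exclamation, epoch):
    # Returns the text instead of printing, so the three versions can be compared
    return f"{exclamation}! We've finished epoch {epoch}!"

def make_message(exclamation):
    def _inner(epoch): return message(exclamation, epoch)  # closure captures exclamation
    return _inner

via_lambda  = lambda o: message("OK I guess", o)
via_closure = make_message("OK I guess")
via_partial = partial(message, "OK I guess")

outs = [cb(3) for cb in (via_lambda, via_closure, via_partial)]
print(outs[0])         # OK I guess! We've finished epoch 3!
print(len(set(outs)))  # 1: all three produce the same string

# A partial object remembers the function and the pre-filled arguments
print(via_partial.func is message, via_partial.args, via_partial.keywords)
# True ('OK I guess',) {}
```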

It just has to be a callable.

A callable is something that you can call. And as we've discussed, another way of creating a callable is defining `__call__`.

So here's a class that works exactly the same as our `make_show_progress` function, but as a class: `__init__` stores the exclamation and `__call__` prints it. So now we're creating an object which is callable and does exactly the same thing.

These are all fundamental ideas that I want you to get really comfortable with: `__call__`, dunder methods in general, partials and classes. They come up all the time in PyTorch code, in the code we'll be writing and, in fact, in pretty much all frameworks. So it's really important to feel comfortable with them. And remember, you don't have to rely only on the resources we're providing: if some things here are very new to you, Google around for tutorials, and ask for help in the forums.



### Callbacks as callable classes

In [71]:
class ProgressShowingCallback():
    def __init__(self, exclamation="Awesome"): self.exclamation = exclamation
    def __call__(self, epoch): print(f"{self.exclamation}! We've finished epoch {epoch}!")

In [72]:
cb = ProgressShowingCallback("Just super")

In [73]:
slow_calculation(cb)

Just super! We've finished epoch 0!
Just super! We've finished epoch 1!
Just super! We've finished epoch 2!
Just super! We've finished epoch 3!
Just super! We've finished epoch 4!


30
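One advantage of a callable class over a plain function is that it can keep state between calls. A minimal sketch (the `CountingCallback` name is hypothetical) that remembers every epoch it was called with; `callable()` confirms that defining `__call__` is what makes the instance callable. A plain loop stands in for `slow_calculation` so it runs instantly.

```python
class CountingCallback:
    """Hypothetical callable callback that keeps state between calls."""
    def __init__(self, exclamation="Awesome"):
        self.exclamation = exclamation
        self.epochs = []                 # state that persists across calls

    def __call__(self, epoch):
        self.epochs.append(epoch)
        print(f"{self.exclamation}! We've finished epoch {epoch}!")

cb = CountingCallback("Just super")
for epoch in range(3):                   # stand-in for slow_calculation's loop
    cb(epoch)
print(callable(cb), cb.epochs)           # True [0, 1, 2]
```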

### Multiple callback funcs; `*args` and `**kwargs`

Next I'm going to briefly revisit something I've mentioned before, `*args` and `**kwargs`, because again, they come up a lot. I just want to show you how they work. Let's create a function that has `*args` and `**kwargs` and nothing else, and just prints them. Now I'm going to call it, passing `3`, `'a'` and `thing1="hello"`: `f(3, 'a', thing1="hello")`. The first two are passed by position: there's no `name=` in front of them. Things that are passed by position are placed in `*args`, if you have one. It doesn't have to be called `args`; you can call it anything you like, it's the star that matters.

So you can see here that `args` is a tuple containing the positionally passed arguments.

And `kwargs` is a dictionary containing the named (keyword) arguments. That is all that `*args` and `**kwargs` do. And as I say, there's nothing special about these names: I can call this one `a` and that one `b`, and it'll do exactly the same thing.

`def f(*a, **b): print(f"args: {a}; kwargs: {b}")`

So this comes up a lot, and it's important to remember that this is literally all they're doing. On the other hand, let's say we had a function `g` which takes three parameters, `a`, `b` and `c`, and just prints them. Rather than only using stars in a parameter list, we can also use them when calling something. So let's create something called `args` (again, it doesn't have to be called that) containing `[1,2]`, and something called `kwargs` containing the dictionary `{'c':3}`.

Then I can call `g(*args, **kwargs)`. That takes the `1, 2` and passes them as individual positional arguments, and takes `{'c':3}` and passes it as the named argument `c=3`. And there it is. So these are two linked but different ways to use `*` and `**`: collecting arguments in a function definition, and unpacking them in a function call.



In [74]:
def f(*a, **b): print(f"args: {a}; kwargs: {b}")

In [75]:
f(3, 'a', thing1="hello")

args: (3, 'a'); kwargs: {'thing1': 'hello'}


In [76]:
def g(a,b,c=0): print(a,b,c)

In [77]:
args = [1,2]
kwargs = {'c':3}
g(*args, **kwargs)

1 2 3
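The two uses of the stars often appear together: a wrapper collects whatever arguments it receives with `*args, **kwargs`, then unpacks them unchanged into the function it wraps. A minimal sketch; `call_logged` is a hypothetical helper, and `g` is redefined to return its arguments (rather than print them) so the result can be checked.

```python
def g(a, b, c=0):
    return (a, b, c)

def call_logged(fn, *args, **kwargs):
    # Collect: args is a tuple, kwargs is a dict
    print(f"calling {fn.__name__} with args={args} kwargs={kwargs}")
    # Unpack: pass them on exactly as they were received
    return fn(*args, **kwargs)

out = call_logged(g, 1, 2, c=3)  # prints: calling g with args=(1, 2) kwargs={'c': 3}
print(out)                       # (1, 2, 3)
```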


In [78]:
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        if cb: cb.before_calc(i)
        res += i*i
        sleep(1)
        if cb: cb.after_calc(i, val=res)
    return res

Now here's a slightly different way of doing callbacks, which I really like. I'm now passing in a callback that isn't callable; instead, it has a method called `before_calc` and another called `after_calc`. So my callback is a class containing `before_calc` and `after_calc` methods. If I run that, you can see there it goes.

This prints before and after every step, by calling `before_calc` and `after_calc`. So a callback doesn't have to be a callable, and it doesn't have to be a function: a callback can be an object that contains methods. Notice that `after_calc` is passed both the epoch number and the value it's up to (`val=res`). By using `*args` and `**kwargs`, I can safely ignore them if I don't want them: they just get swallowed up without complaint.

If I didn't have those there, it wouldn't work: `after_calc` gets passed `val=...`, there's no parameter looking for `val`, and Python raises a `TypeError`. So this is one good use of `*args` and `**kwargs`: eating up arguments you don't want.

Or we could use the arguments. So let's actually use `epoch` and `val` and print them out, and there it is. This is a more sophisticated callback that gives us status as we go.


In [79]:
class PrintStepCallback():
    def before_calc(self, *args, **kwargs): print(f"About to start")
    def after_calc (self, *args, **kwargs): print(f"Done step")

In [80]:
slow_calculation(PrintStepCallback())

About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step


30

In [81]:
class PrintStatusCallback():
    def __init__(self): pass
    def before_calc(self, epoch, **kwargs): print(f"About to start: {epoch}")
    def after_calc (self, epoch, val, **kwargs): print(f"After {epoch}: {val}")

In [82]:
slow_calculation(PrintStatusCallback())

About to start: 0
After 0: 0
About to start: 1
After 1: 1
About to start: 2
After 2: 5
About to start: 3
After 3: 14
About to start: 4
After 4: 30


30
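To see why the `**kwargs` matters, here is a sketch of two callbacks that differ only in that parameter (`StrictCallback` and `TolerantCallback` are hypothetical names). Calling `after_calc` the way `slow_calculation` does, with `val=res`, works for the tolerant one (which behaves like `PrintStatusCallback` when it ignores `val`) and raises a `TypeError` for the strict one.

```python
class StrictCallback:
    # No **kwargs: there is nowhere to put the val= keyword
    def after_calc(self, epoch): print(f"After {epoch}")

class TolerantCallback:
    # Extra keyword arguments are swallowed by **kwargs
    def after_calc(self, epoch, **kwargs): print(f"After {epoch}")

res = 14
TolerantCallback().after_calc(3, val=res)      # prints "After 3"
try:
    StrictCallback().after_calc(3, val=res)    # same call shape as in slow_calculation
except TypeError as e:
    print("TypeError:", e)                     # unexpected keyword argument 'val'
```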

### Modifying behavior

In [83]:
def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        if cb and hasattr(cb,'before_calc'): cb.before_calc(i)
        res += i*i
        sleep(1)
        if cb and hasattr(cb,'after_calc'):
            if cb.after_calc(i, res):
                print("stopping early")
                break
    return res

In [84]:
class PrintAfterCallback():
    def after_calc (self, epoch, val):
        print(f"After {epoch}: {val}")
        if val>10: return True

In [85]:
slow_calculation(PrintAfterCallback())

After 0: 0
After 1: 1
After 2: 5
After 3: 14
stopping early


14

In [86]:
class SlowCalculator():
    def __init__(self, cb=None): self.cb,self.res = cb,0
    
    def callback(self, cb_name, *args):
        if not self.cb: return
        cb = getattr(self.cb,cb_name, None)
        if cb: return cb(self, *args)

    def calc(self):
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            sleep(1)
            if self.callback('after_calc', i):
                print("stopping early")
                break

In [87]:
class ModifyingCallback():
    def after_calc (self, calc, epoch):
        print(f"After {epoch}: {calc.res}")
        if calc.res>10: return True
        if calc.res<3: calc.res = calc.res*2

In [88]:
calculator = SlowCalculator(ModifyingCallback())

In [89]:
calculator.calc()
calculator.res

After 0: 0
After 1: 1
After 2: 6
After 3: 15
stopping early


15
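The section heading mentions multiple callbacks, and `SlowCalculator.callback` already shows the key trick: look the method up by name with `getattr(..., None)` and skip callbacks that don't define it. Here is a sketch extending that to a list of callbacks, where the loop stops early if any of them returns `True`. `MultiCalculator`, `Recorder` and `StopAbove` are hypothetical names, and `sleep` is dropped so it runs instantly.

```python
class MultiCalculator:
    """Hypothetical variant of SlowCalculator that accepts a list of callbacks."""
    def __init__(self, cbs=()):
        self.cbs, self.res = list(cbs), 0

    def callback(self, cb_name, *args):
        stop = False
        for cb in self.cbs:
            method = getattr(cb, cb_name, None)   # skip callbacks without this method
            if method and method(self, *args): stop = True
        return stop                               # True if any callback asked to stop

    def calc(self):
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            if self.callback('after_calc', i):
                print("stopping early")
                break

class Recorder:
    def __init__(self): self.vals = []
    def after_calc(self, calc, epoch): self.vals.append(calc.res)

class StopAbove:
    def __init__(self, limit): self.limit = limit
    def after_calc(self, calc, epoch): return calc.res > self.limit

rec = Recorder()
calculator = MultiCalculator([rec, StopAbove(10)])
calculator.calc()                   # prints "stopping early"
print(calculator.res, rec.vals)     # 14 [0, 1, 5, 14]
```

Every callback still runs on the step where one of them asks to stop, so the recorder sees the final value too.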

## `__dunder__` thingies

Anything that looks like `__this__` is, in some way, *special*. Python, or some library, can define special methods that it will call at certain documented times. For instance, when your class is setting up a new object, Python will call `__init__`. These are defined as part of the Python [data model](https://docs.python.org/3/reference/datamodel.html#object.__init__).

For instance, if Python sees `+`, then it will call the special method `__add__`. If you try to display an object in Jupyter (or lots of other places in Python) it will call `__repr__`.


Okay, so finally, let's review the idea of dunders, which we've mentioned before, just to really nail it home. Anything that looks like `__something__` is special. It could be that Python defines that special thing, or PyTorch does, or NumPy does. These are called dunder methods, and some of them are defined as part of the Python data model.

If you go to the Python documentation, it'll tell you about these various methods: there's `__repr__`, which we used earlier, and `__init__`, which we also used earlier. They're all there. PyTorch has some of its own, and NumPy has some of its own.

For example, when Python sees `+`, what it actually does is call `__add__`. So let's create something that's not very good at adding things: it always adds an extra 0.01. Then `SloppyAdder(1) + SloppyAdder(2)` gives 3.01. The `+` here is actually calling `__add__`. If you're not familiar with this, click on the data model link and read about the eleven methods listed below, because we'll be using all of them in the course.

I'll try to revise them when we can, but I'm generally going to assume you know them. Particularly interesting ones are `__getattr__` and `__getitem__`. We've seen `__setattr__` already; `__getattr__` is just the opposite. Take a look at this: here's a class that just contains two attributes, `a` and `b`, set to 1 and 2. If we create an object of that class, `a.b` equals 2, because I set `b` to 2. Writing `.b` is just syntax sugar: behind the scenes, `a.b` is the same as calling the built-in function `getattr(a, 'b')`, which hopefully makes sense.

In [90]:
class SloppyAdder():
    def __init__(self,o): self.o=o
    def __add__(self,b): return SloppyAdder(self.o + b.o + 0.01)
    def __repr__(self): return str(self.o)

In [91]:
a = SloppyAdder(1)
b = SloppyAdder(2)
a+b

3.01
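Because `+` is just a call to `__add__`, anything in Python that uses `+` internally picks up the sloppy behaviour too. A small sketch with `sum()`: by default it starts from the integer `0`, and `0 + SloppyAdder(1)` fails because `int` can't add a `SloppyAdder` and `SloppyAdder` defines no `__radd__`. Passing a `SloppyAdder(0)` start value works, and each of the three additions adds the extra 0.01. The class is repeated so the block runs on its own.

```python
import math

class SloppyAdder():
    def __init__(self,o): self.o=o
    def __add__(self,b): return SloppyAdder(self.o + b.o + 0.01)
    def __repr__(self): return str(self.o)

xs = [SloppyAdder(1), SloppyAdder(2), SloppyAdder(3)]
total = sum(xs, SloppyAdder(0))       # three calls to __add__
print(math.isclose(total.o, 6.03))    # True

# `a + b` and `a.__add__(b)` are the same computation
print(math.isclose((xs[0] + xs[1]).o, xs[0].__add__(xs[1]).o))  # True
```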

Special methods you should probably know about (see data model link above) are:

- `__getitem__`
- `__getattr__`
- `__setattr__`
- `__del__`
- `__init__`
- `__new__`
- `__enter__`
- `__exit__`
- `__len__`
- `__repr__`
- `__str__`
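A quick sketch touching a few of the methods in that list: `__len__` and `__getitem__` make an object work with `len()` and indexing (and, as a fallback, iteration), `__repr__` controls how it is displayed, and `__enter__` and `__exit__` let it be used in a `with` block. The `Batch` class is hypothetical, purely for illustration.

```python
class Batch:
    """Hypothetical container illustrating a few dunder methods."""
    def __init__(self, items): self.items, self.open = list(items), False
    def __len__(self): return len(self.items)            # len(b)
    def __getitem__(self, i): return self.items[i]       # b[i]; also enables iteration
    def __repr__(self): return f"Batch({self.items})"    # shown by Jupyter and repr()
    def __enter__(self):                                  # start of the `with` block
        self.open = True
        return self
    def __exit__(self, exc_type, exc, tb):                # end of the `with` block
        self.open = False
        return False                                      # don't swallow exceptions

b = Batch([10, 20, 30])
print(len(b), b[1], list(b), b)   # 3 20 [10, 20, 30] Batch([10, 20, 30])
with b as bb:
    print(bb.open)                # True
print(b.open)                     # False
```

Here `list(b)` works even without `__iter__`: Python falls back to calling `__getitem__` with 0, 1, 2, ... until it raises `IndexError`.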

### `__getattr__` and `getattr`

In [92]:
class A: a,b=1,2

In [93]:
a = A()

In [94]:
a.b

2

In [95]:
getattr(a, 'b')

2
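`getattr` also takes an optional default that is returned when the attribute doesn't exist. That's the `getattr(self.cb, cb_name, None)` trick `SlowCalculator` used earlier to skip missing callback methods. A small sketch alongside its siblings `hasattr` and `setattr` (class `A` is repeated so the block runs on its own).

```python
class A: a,b=1,2
a = A()

print(getattr(a, 'b'))              # 2, same as a.b
print(getattr(a, 'missing', None))  # None: the default instead of an AttributeError
print(hasattr(a, 'missing'))        # False
setattr(a, 'c', 3)                  # same as a.c = 3
print(a.c)                          # 3
```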

This can be fun, because you could call `getattr(a, ...)` with either `'b'` or `'a'` chosen at random. If I run this a few times I get 2, 1, 1, 1, 2: as you can see, it's random. Python is such a dynamic language that you can even set things up so you literally don't know which attribute is going to be accessed.

Now, when normal attribute lookup fails, Python falls back to calling a method named `__getattr__`, if the class defines one. So here's a class just like `A`, with `a` and `b` defined, but I've also defined `__getattr__`. `__getattr__` is only called for attributes that haven't been found the normal way, and it's passed the name of the attribute. Generally speaking, if the first character is an underscore, the attribute is private or special, so in that case we raise an `AttributeError`. Otherwise we return `'Hello from {k}'`. So `b.a` is defined, and gives me 1. `b.foo` is not defined, so Python calls `__getattr__` and I get back 'Hello from foo'. This gets used a lot in both fastai and Hugging Face code to make it more convenient to access things. So that's how the `getattr` function and the `__getattr__` method work.

In [96]:
getattr(a, 'b' if random.random()>0.5 else 'a')

1

In [97]:
class B:
    a,b=1,2
    def __getattr__(self, k):
        if k[0]=='_': raise AttributeError(k)
        return f'Hello from {k}'

In [98]:
b = B()

In [99]:
b.a

1

In [100]:
b.foo

'Hello from foo'
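A common use of `__getattr__` in libraries is delegation: a wrapper object forwards any attribute it doesn't have itself to an object it wraps. A minimal sketch of the pattern (`Wrapper` and its `inner` attribute are hypothetical; this is not fastai's or Hugging Face's actual code). The underscore check mirrors class `B` above and keeps special and private names from being forwarded.

```python
class Wrapper:
    """Hypothetical wrapper that delegates unknown attributes to `inner`."""
    def __init__(self, inner): self.inner = inner
    def __getattr__(self, k):
        # Only called when normal lookup fails, so `inner` itself never comes here
        if k.startswith('_'): raise AttributeError(k)
        return getattr(self.inner, k)

w = Wrapper([3, 1, 2])
print(w.inner)                # [3, 1, 2]: found normally, __getattr__ not called
print(w.count(1))             # 1: 'count' isn't on Wrapper, so it's forwarded to the list
print(w.index(2))             # 2
print(hasattr(w, '_secret'))  # False
```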

Okay, so I went over that pretty quickly, since I know for quite a few folks this will all be review, but for folks who haven't seen any of this, it's a lot to cover. So I'm hoping you'll go back over it, revise it slowly, experiment with it, look up some additional resources, and ask on the forum if anything isn't clear. Remember, everybody has parts of the course that are really easy for them and parts that are completely unfamiliar. If this particular part is completely unfamiliar to you, it's not because it's harder or more difficult. It just happens to be a bit you're less familiar with, or maybe the calculus in the last lesson was the bit you were less familiar with. There isn't really anything in the course that's more difficult than other parts; it just depends on whether you happen to have that background.

So if you spend a few hours studying and practicing, you'll be able to pick these things up. Don't stress if there are things you don't get right away; just take the time. And if you do get lost, please ask: if you've tried asking on the forum, hopefully you've noticed that people are really keen to help. All right, I think this has been a pretty successful lesson. We've got to a point where we have a pretty nicely optimized training loop. We understand exactly what data loaders and datasets do. We've got an optimizer. We've been playing with Hugging Face datasets and have those working really smoothly. So we're in a pretty good position to write our generic learner training loop, and then we can start building and experimenting with lots of models. Looking forward to seeing you next time to do that together.
