Add initial code and docs
commit f197aa8
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Sat Aug 26 12:00:02 2023 +0100

    update font paths

commit 03a4c81
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Aug 25 19:47:19 2023 +0100

    touchups

commit 28d1fb7
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Aug 25 15:13:42 2023 +0100

    add introductory api page texts

commit 1b13eef
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Thu Aug 24 20:04:57 2023 +0100

    add missing documentation. only need to fix a few things now

commit a35a5c4
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Wed Aug 23 18:18:42 2023 +0100

    complete noise schedule api page

commit 514b0b2
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Wed Aug 23 17:42:39 2023 +0100

    write most of the api reference pages

commit 0317c2a
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Mon Aug 21 17:14:27 2023 +0100

    reduce text max width

commit 7c783f0
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Mon Aug 21 17:11:59 2023 +0100

    improve mobile styles and add click on article to hide sidebar

commit 5501a95
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Mon Aug 21 17:01:40 2023 +0100

    change code font, make responsive for mobile, small fixes

commit fb8d697
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Sat Aug 19 20:30:20 2023 +0100

    make scroll event highlight current section

commit 02ae584
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Sat Aug 19 18:34:46 2023 +0100

    add index

commit f50914a
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Sat Aug 19 12:37:57 2023 +0100

    add menu transitions

commit 749c083
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Aug 18 19:02:08 2023 +0100

    add logo and make it clickable

commit 807ce36
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Aug 18 18:21:56 2023 +0100

    responsive sidebar

commit 5079bed
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Aug 18 15:56:46 2023 +0100

    further improve styles, add sidebar and header

commit aa9ada4
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Wed Aug 16 22:53:24 2023 +0100

    more formatting changes

commit f442b04
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Wed Aug 16 15:41:10 2023 +0100

    new documentation styles

commit 62e8734
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Sun Aug 6 17:08:37 2023 +0100

    docs initial structure

commit 17276ef
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Aug 4 12:29:49 2023 +0100

    docs minimal setup

commit 04f8463
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Aug 4 12:21:08 2023 +0100

    code complete

commit 284bae6
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Thu Jul 27 16:57:56 2023 +0100

    big changes

commit 14e20ae
Merge: 5bb376f 9af0221
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Thu Jun 22 10:53:46 2023 +0100

    Merge branch 'main' of https://github.com/cabralpinto/modular-diffusion into main

commit 5bb376f
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Thu Jun 22 10:53:42 2023 +0100

    more

commit 9af0221
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Wed Jun 21 16:53:42 2023 +0100

    Update README.md

commit eeea7eb
Merge: b1c03f1 ca3f032
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Tue Jun 20 15:39:29 2023 +0100

    more stuff

commit b1c03f1
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Tue Jun 20 15:38:46 2023 +0100

    more stuff

commit ca3f032
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Mon Jun 19 18:22:39 2023 +0100

    Update README.md

commit e59ae98
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Mon Jun 19 18:22:05 2023 +0100

    Update README.md

commit 3cda2a7
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Mon Jun 19 18:18:51 2023 +0100

    Update README.md

commit ae40f70
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Mon Jun 19 18:16:00 2023 +0100

    Update README.md

commit 4f34168
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Mon Jun 19 12:03:19 2023 +0100

    update readme

commit 662caba
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Mon Jun 19 11:51:46 2023 +0100

    update readme

commit 34a359b
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Sun Jun 18 17:11:01 2023 +0100

    update readme

commit dc5b3cd
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Sun Jun 18 17:04:22 2023 +0100

    update readme

commit f1609ac
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Sun Jun 18 13:28:24 2023 +0100

    updated readme

commit 141a498
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Sat Jun 17 15:14:09 2023 +0100

    added mnist data, requirements.txt

commit 1f3c667
Author: cabralpinto <jmcabralpinto@gmail.com>
Date:   Fri Jun 16 12:11:01 2023 +0100

    added initial code and examples

commit 57b94d4
Author: João Cabral Pinto <47889626+cabralpinto@users.noreply.github.com>
Date:   Fri Jun 16 11:01:49 2023 +0100

    Initial commit
cabralpinto committed Aug 26, 2023
1 parent 8bc82e9 commit 08a3bfd
Showing 83 changed files with 9,559 additions and 201 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,3 +3,5 @@
__pycache__
env
nohup.out
private
examples/data
94 changes: 64 additions & 30 deletions diffusion/__init__.py
@@ -1,17 +1,18 @@
import sys
from dataclasses import dataclass, field
from functools import partial
from itertools import chain
from pathlib import Path
from typing import Callable, Generic, Iterator, Optional, TypeVar

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter
from torch.optim import AdamW, Optimizer
from torch.optim import Adam, Optimizer
from tqdm import tqdm

from . import data, guidance, loss, net, noise, schedule, time
from .base import Batch, Data, Distribution, Guidance, Loss, Net, Noise, Schedule, Time
from .time import Uniform
from .time import Discrete

__all__ = ["data", "loss", "net", "noise", "schedule", "time", "Model"]

@@ -25,64 +26,97 @@ class Model(Generic[D]):
noise: Noise[D]
loss: Loss[D]
net: Net
time: Time = field(default_factory=Uniform)
time: Time = field(default_factory=Discrete)
guidance: Optional[Guidance] = None # TODO remove hardcoding
optimizer: Optional[Optimizer | Callable[..., Optimizer]] = None
device: torch.device = torch.device("cuda")

@property
def parameters(self) -> Iterator[Parameter]:
return self.net.parameters() # TODO add all parameters
device: torch.device = torch.device("cpu")
compile: bool = True

@torch.no_grad()
def __post_init__(self):
self.noise.schedule(self.schedule.compute().to(self.device))
parameters = chain(self.data.parameters(), self.net.parameters())
if self.optimizer is None:
self.optimizer = partial(AdamW, lr=1e-4)
if callable(self.optimizer):
self.optimizer = self.optimizer(self.parameters)
self.net = self.net.to(self.device) # TODO: add all tensors to device
if sys.version_info <= (3, 10):
self.optimizer = Adam(parameters, lr=1e-4)
elif callable(self.optimizer):
self.optimizer = self.optimizer(parameters)
self.net = self.net.to(self.device)
for name, value in vars(self.data).items():
if isinstance(value, nn.Module):
setattr(self.data, name, value.to(self.device))
if self.compile and sys.version_info < (3, 11):
self.net = torch.compile(self.net) # type: ignore[union-attr]

@torch.no_grad()
def load(self, path: Path | str):
state = torch.load(path)
self.net.load_state_dict(state["net"])
for name, dict in state["data"].items():
getattr(self.data, name).load_state_dict(dict)

@torch.no_grad()
def save(self, path: Path | str):
state = {
"net": self.net.state_dict(),
"data": {
name: value.state_dict()
for name, value in vars(self.data).items()
if isinstance(value, nn.Module)
}
}
torch.save(state, path)

@torch.enable_grad()
def train(self, epochs: int = 1) -> Iterator[float]:
def train(self, epochs: int = 1, progress: bool = True) -> Iterator[float]:
self.net.train()
batch = Batch[D](self.device)
for _ in range(epochs):
for batch.w, batch.y in (bar := tqdm(self.data)):
bar = tqdm(self.data, disable=not progress)
for batch.w, batch.y in self.data:
if isinstance(self.guidance, guidance.ClassifierFree):
i = torch.randperm(batch.y.shape[0])
batch.y[i[:int(batch.y.shape[0] * self.guidance.dropout)]] = 0
batch.x = self.data.encode(batch.w)
batch.t = self.time(self.schedule.steps, batch.x.shape[0])
batch.t = self.time.sample(self.schedule.steps, batch.x.shape[0])
batch.z, batch.epsilon = self.noise.prior(batch.x, batch.t).sample()
batch.hat = self.net(batch.z, batch.y, batch.t)
batch.q = self.noise.posterior(batch.x, batch.z, batch.t)
batch.p = self.noise.approximate(batch.z, batch.t, batch.hat)
batch.l = self.loss(batch)
bar.set_postfix(loss=batch.l.item())
batch.l = self.loss.compute(batch)
self.optimizer.zero_grad() # type: ignore[union-attr]
batch.l.backward()
self.optimizer.step() # type: ignore[union-attr]
bar.set_postfix(loss=f"{batch.l.item():.2e}")
bar.update()
bar.close()
yield batch.l.item()

@torch.no_grad()
def sample(self, y: Optional[Tensor] = None, batch: int = 1) -> Tensor:
if y is None:
y = torch.zeros(batch, dtype=torch.int, device=self.device)
def sample(
self,
y: Optional[Tensor] = None,
batch: int = 1,
progress: bool = True,
) -> Tensor:
self.net.eval()
y = y.to(self.device)
pi = self.noise.isotropic(y.shape[0], *self.data.shape)
if y is None:
shape = 1, *(() if self.data.y is None else self.data.y.shape[1:])
y = torch.zeros(shape, dtype=torch.int, device=self.device)
y = y.repeat_interleave(batch, 0).to(self.device)
pi = self.noise.isotropic((y.shape[0], *self.data.shape))
z = pi.sample()[0].to(self.device)
l = torch.zeros(0, y.shape[0], *self.data.shape, device=self.device)
for t in tqdm(range(self.schedule.steps, 0, -1)):
t = torch.full_like(y, t)
l = self.data.decode(z)[None]
bar = tqdm(total=self.schedule.steps, disable=not progress)
for t in range(self.schedule.steps, 0, -1):
t = torch.full((batch,), t, device=self.device)
hat = self.net(z, y, t)
if isinstance(self.guidance, guidance.ClassifierFree):
s = self.guidance.weight
hat = s * hat + (1 - s) * self.net(z, torch.zeros_like(y), t)
s = self.guidance.scale
hat = (1 + s) * hat - s * self.net(z, torch.zeros_like(y), t)
z, _ = self.noise.approximate(z, t, hat).sample()
# print(z.min().item(), z.max().item(), flush=True)
w = self.data.decode(z)
l = torch.cat((l, w[None]), 0)
bar.update()
bar.close()
return l
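
To see how the reworked Model API above hangs together, here is a minimal usage sketch. Only Model's own fields and methods (train, sample, save) are taken from this diff; the concrete component names (Identity, Gaussian, Linear, UNet) are assumed to live elsewhere in the package, and the data is a stand-in.

import torch
from diffusion import Model
from diffusion.data import Identity        # assumed component
from diffusion.loss import Simple
from diffusion.net import UNet             # assumed component
from diffusion.noise import Gaussian       # assumed component
from diffusion.schedule import Linear      # assumed component

w = torch.randn(128, 1, 28, 28)            # hypothetical training set
model = Model(
    data=Identity(w, batch=16, shuffle=True),
    schedule=Linear(1000),
    noise=Gaussian(),
    loss=Simple(),
    net=UNet(),
    device=torch.device("cpu"),            # new default per this diff
    compile=False,                         # skip torch.compile for the sketch
)
for loss in model.train(epochs=2):         # train() yields one loss value per epoch
    print(f"epoch loss: {loss:.2e}")
x = model.sample(batch=4)                  # decoded samples at every timestep
model.save("checkpoint.pt")                # state dicts for net + data modules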
42 changes: 26 additions & 16 deletions diffusion/base.py
@@ -1,12 +1,15 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from itertools import chain
from typing import Any, Callable, Generic, Iterator, Optional, TypeVar

import torch
import torch.nn as nn
from torch import Tensor
from typing_extensions import Self

from .utils.nn import Sequential

__all__ = ["Batch", "Data", "Distribution", "Loss", "Net", "Noise", "Schedule", "Time"]


@@ -53,29 +56,33 @@ def __setattr__(self, prop: str, val: Any):

@dataclass
class Data(ABC):
x: Tensor
w: Tensor
y: Optional[Tensor] = None
batch: int = 1
shuffle: bool = False

@property
def shape(self) -> tuple[int]:
return self.x.shape[1:]
def shape(self) -> tuple[int, ...]:
return self.encode(self.w[:1]).shape[1:]

def parameters(self) -> Iterator[nn.Parameter]:
return chain.from_iterable(
var.parameters() for var in vars(self) if isinstance(var, nn.Module))

def __iter__(self) -> Iterator[tuple[Tensor, Tensor]]:
if self.y is None:
self.y = torch.zeros(self.x.shape[0], dtype=torch.int)
self.y = torch.zeros(self.w.shape[0], dtype=torch.int)
if self.shuffle:
index = torch.randperm(self.x.shape[0])
self.x, self.y = self.x[index], self.y[index]
self.data = zip(self.x.split(self.batch), self.y.split(self.batch))
index = torch.randperm(self.w.shape[0])
self.w, self.y = self.w[index], self.y[index]
self.data = zip(self.w.split(self.batch), self.y.split(self.batch))
return self

def __next__(self) -> tuple[Tensor, Tensor]:
return next(self.data)

def __len__(self) -> int:
return self.x.shape[0] // self.batch
return -(self.w.shape[0] // -self.batch)

@abstractmethod
def encode(self, w: Tensor) -> Tensor:
@@ -92,16 +99,14 @@ class Time(ABC):
def sample(self, steps: int, size: int) -> Tensor:
raise NotImplementedError

def __call__(self, steps: int, size: int) -> Tensor:
return self.sample(steps, size)


@dataclass
class Schedule(ABC):
steps: int

@abstractmethod
def compute(self) -> Tensor:
"""Compute the diffusion schedule alpha_t for t = 0, ..., T"""
raise NotImplementedError


@@ -110,22 +115,27 @@ class Noise(ABC, Generic[D]):

@abstractmethod
def schedule(self, alpha: Tensor) -> None:
"""Precompute needed resources based on the diffusion schedule"""
raise NotImplementedError

@abstractmethod
def isotropic(self, *shape: int) -> D:
def isotropic(self, shape: tuple[int, ...]) -> D:
"""Compute the isotropic distribution q(x_T)"""
raise NotImplementedError

@abstractmethod
def prior(self, x: Tensor, t: Tensor) -> D:
"""Compute the prior distribution q(x_t | x_0)"""
raise NotImplementedError

@abstractmethod
def posterior(self, x: Tensor, z: Tensor, t: Tensor) -> D:
"""Compute the posterior distribution q(x_{t-1} | x_t, x_0)"""
raise NotImplementedError

@abstractmethod
def approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> D:
"""Compute the approximate posterior distribution p(x_{t-1} | x_t)"""
raise NotImplementedError


@@ -136,6 +146,9 @@ class Net(ABC, nn.Module):
def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
raise NotImplementedError

def __or__(self, module: nn.Module) -> "Net":
return Sequential(self, module) # type: ignore


class Guidance(ABC):
pass
@@ -147,9 +160,6 @@ class Loss(ABC, Generic[D]):
def compute(self, batch: Batch[D]) -> Tensor:
raise NotImplementedError

def __call__(self, batch: Batch[D]) -> Tensor:
return self.compute(batch)

def __mul__(self, factor: float) -> "Mul[D]":
return Mul(factor, self)

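
Two details in base.py are worth flagging. First, Data.__len__ now uses -(n // -batch), the integer ceiling-division idiom, so a trailing partial batch is counted. Second, the new Net.__or__ overload pipes a post-processing module onto a network via the package's own Sequential from diffusion.utils.nn. A hedged sketch (nn.Softmax as the appended module is an assumption, as is Sequential forwarding only x to modules after the first):

import torch.nn as nn
from diffusion.net import UNet             # assumed component

net = UNet() | nn.Softmax(dim=3)           # Net.__or__ wraps both in Sequential

For the Gaussian case, the distributions named in the new Noise docstrings are the usual DDPM expressions (Ho et al., 2020), with $\bar\alpha_t = \prod_{s \le t} \alpha_s$:

$q(x_T) \approx \mathcal{N}(0, I), \quad q(x_t \mid x_0) = \mathcal{N}\!\left(\sqrt{\bar\alpha_t}\,x_0,\ (1 - \bar\alpha_t) I\right), \quad q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(\tilde\mu_t(x_t, x_0),\ \tilde\beta_t I\right)$

while the approximate posterior $p(x_{t-1} \mid x_t)$ substitutes the network estimate for $x_0$ (or $\epsilon$) in the posterior mean.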
16 changes: 7 additions & 9 deletions diffusion/data.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor, nn
@@ -17,12 +18,13 @@ def decode(self, x: Tensor) -> Tensor:

@dataclass
class OneHot(Data):
k: int = 2
dimension: Optional[int] = None

def __post_init__(self):
self.i = torch.eye(self.k)
self.i = torch.eye(self.dimension)

def encode(self, w: Tensor) -> Tensor:
self.i = self.i.to(w.device) # TODO change
return self.i[w]

def decode(self, x: Tensor) -> Tensor:
@@ -31,15 +33,11 @@ def decode(self, x: Tensor) -> Tensor:

@dataclass
class Embedding(Data):
k: int = 2
d: int = 256
count: Optional[int] = None
dimension: Optional[int] = None

def __post_init__(self) -> None:
self.embedding = nn.Embedding(self.k, self.d)

@property
def shape(self) -> tuple[int]:
return (*self.x.shape[1:], self.d)
self.embedding = nn.Embedding(self.count, self.dimension)

def encode(self, w: Tensor) -> Tensor:
return self.embedding(w)
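
The renames above (k to dimension, and k/d to count/dimension) make the constructors read more naturally. A hedged usage sketch, assuming token-id inputs:

import torch
from diffusion.data import Embedding, OneHot

w = torch.randint(0, 27, (100, 64))                       # hypothetical token ids
onehot = OneHot(w, batch=10, dimension=27)                # ids -> 27-dim one-hot vectors
embed = Embedding(w, batch=10, count=27, dimension=256)   # ids -> learned 256-dim vectors
x = onehot.encode(w[:1])                                  # shape (1, 64, 27)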
6 changes: 3 additions & 3 deletions diffusion/distribution.py
@@ -34,11 +34,11 @@ def __post_init__(self) -> None:
self.i = torch.eye(self.k, device=self.p.device)

def sample(self) -> tuple[Tensor, None]:
index = torch.multinomial(self.p.view(-1, self.k), 1, True)
return self.i[index.view(*self.p.shape[:-1])], None
c = torch.multinomial(self.p.view(-1, self.k), 1, True)
return self.i[c.view(*self.p.shape[:-1])], None

def nll(self, x: Tensor) -> Tensor:
return -(self.p * x).sum(-1).log()
return -((self.p * x).sum(-1) + 1e-6).log()

def dkl(self, other: Self) -> Tensor:
p1, p2 = self.p + 1e-6, other.p + 1e-6
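
The 1e-6 added to nll mirrors the guard already present in dkl: if a sample lands on a zero-probability category, the log now returns a large finite penalty instead of inf. A quick check:

import torch

p = torch.tensor([[0.0, 1.0]])             # category 0 has zero probability
x = torch.tensor([[1.0, 0.0]])             # one-hot sample of category 0
print(-(p * x).sum(-1).log())              # tensor([inf])      -- old behaviour
print(-((p * x).sum(-1) + 1e-6).log())     # tensor([13.8155])  -- new, finite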
14 changes: 4 additions & 10 deletions diffusion/guidance.py
@@ -1,17 +1,11 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

__all__ = ["Base", "ClassifierFree"]
from .base import Guidance


class Base(ABC):

@abstractmethod
def __init__(self):
raise NotImplementedError
__all__ = ["ClassifierFree"]


@dataclass
class ClassifierFree(Base):
class ClassifierFree(Guidance):
dropout: float
weight: float
scale: float
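
The weight field becomes scale, matching the combination now used in Model.sample above, which extrapolates away from the unconditional prediction rather than interpolating toward it:

$\hat{x} = (1 + s)\,\text{net}(z, y, t) - s\,\text{net}(z, \varnothing, t)$

so s = 0 recovers the plain conditional prediction and larger s strengthens the guidance signal.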
16 changes: 9 additions & 7 deletions diffusion/loss.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Literal
from typing import Callable, Literal, TypeVar

import torch
from torch import Tensor
@@ -8,18 +8,20 @@

__all__ = ["Lambda", "Simple", "VLB"]

D = TypeVar("D", bound=Distribution)


@dataclass
class Lambda(Loss[Distribution]):
function: Callable[[Batch[Distribution]], Tensor]
class Lambda(Loss[D]):
function: Callable[[Batch[D]], Tensor]

def compute(self, batch: Batch[Distribution]) -> Tensor:
def compute(self, batch: Batch[D]) -> Tensor:
return self.function(batch)


@dataclass
class Simple(Loss[Distribution]):
parameter: Literal["x", "epsilon"] = "epsilon"
parameter: Literal["x", "epsilon"] = "x"
index = 0

def compute(self, batch: Batch[Distribution]) -> Tensor:
@@ -29,5 +31,5 @@ def compute(self, batch: Batch[Distribution]) -> Tensor:
class VLB(Loss[Distribution]):

def compute(self, batch: Batch[Distribution]) -> Tensor:
t = batch.t.view(-1, *(1,) * (batch.x.ndim - 2))
return batch.q.dkl(batch.p).where(t > 1, batch.p.nll(batch.x)).sum()
t = batch.t.view(-1, *(1,) * (batch.x.ndim - 1))
return batch.q.dkl(batch.p).where(t > 1, batch.p.nll(batch.x)).mean()
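
Note that Simple's default prediction target flips from "epsilon" to "x" in this commit, and VLB's reduction changes from sum() to mean(), with the t reshape corrected to broadcast over all non-batch dimensions. For custom objectives, Lambda covers anything expressible over a Batch; a hedged sketch, assuming batch.hat holds the raw network output as in Model.train above:

from diffusion.loss import Lambda, VLB

mse = Lambda(lambda batch: ((batch.hat - batch.epsilon) ** 2).mean())  # plain epsilon-MSE
scaled_vlb = VLB() * 1e-3                  # Loss.__mul__ from base.py returns a Mul wrapper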