Skip to content

Commit

Permalink
Allow str device
Browse files Browse the repository at this point in the history
  • Loading branch information
cabralpinto committed Aug 28, 2023
1 parent 8b8a6a6 commit 820f224
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 27 deletions.
27 changes: 5 additions & 22 deletions diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ class Model(Generic[D]):
data: Data
schedule: Schedule
noise: Noise[D]
loss: Loss[D]
net: Net
loss: Loss[D]
time: Time = field(default_factory=Discrete)
guidance: Optional[Guidance] = None # TODO remove hardcoding
optimizer: Optional[Optimizer | Callable[..., Optimizer]] = None
device: torch.device = torch.device("cpu")
device: str | torch.device = torch.device("cpu")
compile: bool = True

@torch.no_grad()
Expand All @@ -43,28 +43,11 @@ def __post_init__(self):
self.net = self.net.to(self.device)
for name, value in vars(self.data).items():
if isinstance(value, nn.Module):
setattr(self.data, name, value.to(self.device))
setattr(self.data, name, value.to(self.device))
if self.compile and sys.version_info < (3, 11):
self.net = torch.compile(self.net) # type: ignore[union-attr]

@torch.no_grad()
def load(self, path: Path | str):
state = torch.load(path)
self.net.load_state_dict(state["net"])
for name, dict in state["data"].items():
getattr(self.data, name).load_state_dict(dict)

@torch.no_grad()
def save(self, path: Path | str):
state = {
"net": self.net.state_dict(),
"data": {
name: value.state_dict()
for name, value in vars(self.data).items()
if isinstance(value, nn.Module)
}
}
torch.save(state, path)
if isinstance(self.device, str):
self.device = torch.device(self.device)

@torch.enable_grad()
def train(self, epochs: int = 1, progress: bool = True) -> Iterator[float]:
Expand Down
2 changes: 1 addition & 1 deletion examples/conditional-diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
net=UNet(channels=(1, 64, 128, 256), labels=10),
guidance=ClassifierFree(dropout=0.1, strength=2),
loss=Simple(parameter="epsilon"),
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
)

if (output / "model.pt").exists():
Expand Down
2 changes: 1 addition & 1 deletion examples/embedding-diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
noise=Gaussian(parameter="x", variance="fixed"),
loss=Simple(parameter="x"),
net=Transformer(input=32, width=1024, depth=16, heads=16),
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
)

if (output / "model.pt").exists():
Expand Down
2 changes: 1 addition & 1 deletion examples/text-diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
noise=Absorbing(len(v)),
loss=VLB() + 1e-2 * Lambda[Cat](lambda batch: Cat(batch.hat[0]).nll(batch.x).sum()),
net=Transformer(input=len(v), width=1024, depth=16, heads=16) | Softmax(3),
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
)

if (output / "model.pt").exists():
Expand Down
2 changes: 1 addition & 1 deletion examples/transformer-diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
noise=Gaussian(parameter="epsilon", variance="fixed"),
net=Transformer(input=x.shape[2], width=768, depth=12, heads=12),
loss=Simple(parameter="epsilon"),
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
)

if (output / "model.pt").exists():
Expand Down
2 changes: 1 addition & 1 deletion examples/unconditional-diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
noise=Gaussian(parameter="epsilon", variance="fixed"),
net=UNet(channels=(1, 64, 128, 256)),
loss=Simple(parameter="epsilon"),
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
device="cuda" if torch.cuda.is_available() else "cpu",
)

if (output / "model.pt").exists():
Expand Down

0 comments on commit 820f224

Please sign in to comment.