Skip to content

Commit

Permalink
make aim optional, as seeing a segmentation fault. add ability to tur…
Browse files Browse the repository at this point in the history
…n off strict loading of load_state_dict for GAN
  • Loading branch information
lucidrains committed Apr 20, 2022
1 parent ee9fff6 commit 72d6a98
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 14 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,16 @@ Also one flag to use `--multi-gpus`

[Aim](https://github.com/aimhubio/aim) is an open-source experiment tracker that logs your training runs, enables a beautiful UI to compare them and an API to query them programmatically.

You can specify Aim logs directory with `--aim_repo` flag, otherwise logs will be stored in the current directory
First, you need to install `aim` with `pip`:

```bash
$ lightweight_gan --data ./path/to/images --image-size 512 --aim_repo ./path/to/logs/
$ pip install aim
```

Next, you can specify the Aim logs directory with the `--aim_repo` flag; otherwise, logs will be stored in the current directory:

```bash
$ lightweight_gan --data ./path/to/images --image-size 512 --use-aim --aim_repo ./path/to/logs/
```

Execute `aim up --repo ./path/to/logs/` to run the Aim UI on your server.
Expand Down
8 changes: 5 additions & 3 deletions lightweight_gan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ def train_from_folder(
seed = 42,
amp = False,
show_progress = False,
use_aim = True,
use_aim = False,
aim_repo = None,
aim_run_hash = None
aim_run_hash = None,
load_strict = True
):
num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

Expand Down Expand Up @@ -149,7 +150,8 @@ def train_from_folder(
calculate_fid_every = calculate_fid_every,
calculate_fid_num_images = calculate_fid_num_images,
clear_fid_cache = clear_fid_cache,
amp = amp
amp = amp,
load_strict = load_strict
)

if generate:
Expand Down
21 changes: 15 additions & 6 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import aim
import json
import multiprocessing
from random import random
Expand Down Expand Up @@ -961,6 +960,7 @@ def __init__(
use_aim = True,
aim_repo = None,
aim_run_hash = None,
load_strict = True,
*args,
**kwargs
):
Expand Down Expand Up @@ -1038,14 +1038,23 @@ def __init__(

self.syncbatchnorm = is_ddp

self.load_strict = load_strict

self.amp = amp
self.G_scaler = GradScaler(enabled = self.amp)
self.D_scaler = GradScaler(enabled = self.amp)

self.run = None
self.hparams = hparams

if self.is_main and use_aim:
self.run = aim.Run(run_hash=aim_run_hash, repo=aim_repo)
try:
import aim
self.aim = aim
except ImportError:
print('unable to import aim experiment tracker - please run `pip install aim` first')

self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo)
self.run['hparams'] = hparams

@property
Expand Down Expand Up @@ -1347,7 +1356,7 @@ def image_to_pil(image):
aim_images = []
for image in images:
im = image_to_pil(image)
aim_images.append(aim.Image(im, caption=f'#{idx}'))
aim_images.append(self.aim.Image(im, caption=f'#{idx}'))

self.run.track(value=aim_images, name='generated',
step=self.steps,
Expand All @@ -1362,7 +1371,7 @@ def image_to_pil(image):
aim_images = []
for idx, image in enumerate(generated_images):
im = image_to_pil(image)
aim_images.append(aim.Image(im, caption=f'#{idx}'))
aim_images.append(self.aim.Image(im, caption=f'#{idx}'))

self.run.track(value=aim_images, name='generated',
step=self.steps,
Expand All @@ -1376,7 +1385,7 @@ def image_to_pil(image):
aim_images = []
for idx, image in enumerate(generated_images):
im = image_to_pil(image)
aim_images.append(aim.Image(im, caption=f'EMA #{idx}'))
aim_images.append(self.aim.Image(im, caption=f'EMA #{idx}'))

self.run.track(value=aim_images, name='generated',
step=self.steps,
Expand Down Expand Up @@ -1597,7 +1606,7 @@ def load(self, num=-1, print_version=True):
print(f"loading from version {load_data['version']}")

try:
self.GAN.load_state_dict(load_data['GAN'])
self.GAN.load_state_dict(load_data['GAN'], strict = self.load_strict)
except Exception as e:
saved_version = load_data['version']
print('unable to load save model. please try downgrading the package to the version specified by the saved model (to do so, just run `pip install lightweight-gan=={saved_version}`')
Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.22.1'
__version__ = '0.22.3'
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
'retry',
'torch>=1.10',
'torchvision',
'tqdm',
'aim'
'tqdm'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down

0 comments on commit 72d6a98

Please sign in to comment.