Skip to content

Commit

Permalink
fix: remove torch script (#139)
Browse files Browse the repository at this point in the history
* fix: remove torch script

* fix: remove torch script

* fix: remove torch script

* fix: remove torch script

* fix: remove torch script
  • Loading branch information
hanxiao committed Aug 4, 2022
1 parent 640211e commit 6871ccc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
42 changes: 21 additions & 21 deletions discoart/nn/make_cutouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,8 @@
from torchvision import transforms as T
from torchvision.transforms import functional as TF

augment = torch.jit.script(
torch.nn.Sequential(
*[
T.RandomHorizontalFlip(p=0.5),
T.RandomAffine(
degrees=10,
translate=(0.05, 0.05),
interpolation=T.InterpolationMode.BILINEAR,
),
T.RandomGrayscale(p=0.1),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
)


class MakeCutoutsDango(nn.Module):
class MakeCutouts(nn.Module):
def __init__(
self,
cut_size,
Expand All @@ -40,9 +21,28 @@ def __init__(
self.InnerCrop = InnerCrop
self.IC_Size_Pow = IC_Size_Pow
self.IC_Grey_P = IC_Grey_P
self.augment = T.Compose(
[
T.RandomHorizontalFlip(p=0.5),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomAffine(
degrees=10,
translate=(0.05, 0.05),
interpolation=T.InterpolationMode.BILINEAR,
),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomGrayscale(p=0.1),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)

def forward(self, input):
return torch.cat([augment(c) for c in self._cut_generator(input)])
return torch.cat([self.augment(c) for c in self._cut_generator(input)])

def _cut_generator(self, input):
gray = T.Grayscale(3)
Expand Down
7 changes: 4 additions & 3 deletions discoart/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
is_jupyter,
)
from .nn.losses import spherical_dist_loss, tv_loss, range_loss
from .nn.make_cutouts import MakeCutoutsDango
from .nn.make_cutouts import MakeCutouts
from .nn.sec_diff import alpha_sigma_to_t
from .nn.transform import symmetry_transformation_fn
from .persist import _sample_thread, _persist_thread, _save_progress_thread
Expand Down Expand Up @@ -55,7 +55,7 @@ def do_run(args, models, device, events) -> 'DocumentArray':

schedule_table = _get_schedule_table(args)

from .nn.perlin_noises import create_perlin_noise, regen_perlin
from .nn.perlin_noises import regen_perlin

skip_steps = args.skip_steps

Expand Down Expand Up @@ -220,7 +220,7 @@ def cond_fn(x, t, **kwargs):
else:
continue

cuts = MakeCutoutsDango(
cuts = MakeCutouts(
model_stat['input_resolution'],
Overview=scheduler.cut_overview,
InnerCrop=scheduler.cut_innercut,
Expand All @@ -229,6 +229,7 @@ def cond_fn(x, t, **kwargs):
)

for _ in range(scheduler.cutn_batches):

clip_in = cuts(x_in.add(1).div(2))

if args.visualize_cuts and not is_cuts_visualized:
Expand Down

0 comments on commit 6871ccc

Please sign in to comment.