In [1]:
from fastai.vision.all import *
from fastai.vision.gan import *

The generator

In [2]:
path = Path('datasetBenignas/')

In [3]:
pathI = path/'images'

In [4]:
def get_dls(bs:int, size:int):
  "Build the generator `DataLoaders`; each image's target is the same-named file under `pathI`"
  # Batch-level augmentations plus normalisation with `imagenet_stats`
  batch_pipeline = [*aug_transforms(max_zoom=2.), Normalize.from_stats(*imagenet_stats)]
  # Target lookup: keep only the file name and resolve it inside `pathI`
  target_of = lambda x: pathI/x.name
  gen_block = DataBlock(blocks=(ImageBlock, ImageBlock),
                        get_items=get_image_files,
                        get_y=target_of,
                        splitter=RandomSplitter(),
                        item_tfms=Resize(size),
                        batch_tfms=batch_pipeline)
  loaders = gen_block.dataloaders(pathI, bs=bs, path=path)
  loaders.c = 3 # For 3 channel image
  return loaders

In [5]:
# Generator data: batch size 1, items resized to 64 by `get_dls`
dls_gen = get_dls(1, 64)

In [6]:
# wd       -> weight decay passed to every `fit*` call in this notebook
# y_range  -> output range handed to `unet_learner` in `create_gen_learner`
# loss_gen -> loss used for the generator pre-training (flattened MSE)
wd, y_range, loss_gen = 1e-3, (-3., 3.), MSELossFlat()

In [7]:
bbone = resnet50

def create_gen_learner():
  "Create the U-Net generator learner on `dls_gen` using the `bbone` architecture"
  # Gather the U-Net options once so the call below stays short
  unet_opts = dict(loss_func=loss_gen, blur=True, norm_type=NormType.Weight,
                   self_attention=True, y_range=y_range)
  return unet_learner(dls_gen, bbone, **unet_opts)

In [8]:
learn_gen = create_gen_learner()

In [9]:
# First pre-training pass: 2 one-cycle epochs with weight decay `wd`;
# `unfreeze()` is only called after this pass.
learn_gen.fit_one_cycle(2, pct_start=0.8, wd=wd)


epoch,train_loss,valid_loss,time
0,0.002231,0.002368,05:00
1,0.000198,0.000223,04:55


In [10]:
learn_gen.unfreeze()

In [11]:
# Second pass, run after `unfreeze()`: 3 epochs with a learning-rate slice from 1e-6 to 1e-3
learn_gen.fit_one_cycle(3, slice(1e-6,1e-3), wd=wd)

epoch,train_loss,valid_loss,time
0,0.000653,0.000291,05:03
1,0.000112,0.000109,05:03
2,5.2e-05,9.5e-05,05:03


In [12]:
learn_gen.save('gen-pre2-resnet50')

Path('datasetBenignas/models/gen-pre2-resnet50.pth')

In [8]:
# Folder that receives the generator's predictions (written by `save_preds`).
# Its name also becomes the class label of those images once the critic data
# is labelled with `parent_label`.
name_gen = 'image_gen'
path_gen = path/name_gen
# `exist_ok=True` makes re-running this cell safe when the folder already exists
path_gen.mkdir(exist_ok=True)

In [9]:
def save_preds(dl, learn):
  "Write the decoded prediction for every item of `dl` into `path_gen`, reusing the source file name"
  src_files = dl.dataset.items
  outputs, _targs = learn.get_preds(dl=dl)
  for idx, out in enumerate(outputs):
    # `decode` works on batch tuples: add a batch axis, wrap in a tuple, then unwrap again
    as_batch = (TensorImage(out[None]),)
    decoded = dl.after_batch.decode(as_batch)[0][0]
    # Move axis 0 to the end and cast to uint8 before handing the array to `Image.fromarray`
    pixels = decoded.numpy().transpose(1,2,0).astype(np.uint8)
    Image.fromarray(pixels).save(path_gen/src_files[idx].name)

In [15]:
# Fresh loader over the training set for exporting predictions: no shuffling, no dropped
# last batch, and `after_batch` replaced by just `IntToFloatTensor` + `Normalize`
# (none of the `aug_transforms` from `get_dls`). `save_preds` pairs prediction i with
# `dl.dataset.items[i]`, so it relies on this loader keeping the dataset order.
dl = dls_gen.train.new(shuffle=False, drop_last=False, 
                       after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

In [16]:
save_preds(dl, learn_gen)

The critic

In [9]:
# Critic inputs: generated images (folder `name_gen`) followed by the original images
# (folder 'images'). `get_crit_dls` labels each file by its parent folder name.
path_g = get_image_files(path/name_gen)
path_i = get_image_files(path/'images')
fnames = path_g + path_i

In [10]:
def get_crit_dls(fnames, bs:int, size:int):
  "Build the `Critic` DataLoaders from `fnames`, labelling each file by its parent folder"
  # 10% of the files are held out at random for validation
  train_valid_idxs = RandomSplitter(0.1)(fnames)
  x_tfms = [PILImage.create]
  y_tfms = [parent_label, Categorize]
  dsets = Datasets(fnames, tfms=[x_tfms, y_tfms], splits=train_valid_idxs)
  per_item = [ToTensor(), Resize(size)]
  per_batch = [IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
  return dsets.dataloaders(bs=bs, after_item=per_item, after_batch=per_batch)

In [11]:
dls_crit = get_crit_dls(fnames, bs=1, size=64)

In [12]:
loss_crit = AdaptiveLoss(nn.BCEWithLogitsLoss())

In [13]:
def create_crit_learner(dls, metrics):
  "Wrap a freshly built `gan_critic` model in a `Learner` that trains with `loss_crit`"
  critic_model = gan_critic()
  return Learner(dls, critic_model, loss_func=loss_crit, metrics=metrics)

In [14]:
learn_crit = create_crit_learner(dls_crit, accuracy_thresh_expand)

In [23]:
# NOTE(review): in the recorded run below, train/valid loss stay at ~0.693 (= ln 2) and
# accuracy at ~0.50 for all 6 epochs, i.e. the critic is at chance level and has not
# learned to separate generated from real images. Revisit before relying on this checkpoint.
learn_crit.fit_one_cycle(6, 1e-3, wd=wd)

epoch,train_loss,valid_loss,accuracy_thresh_expand,time
0,0.693586,0.693627,0.495455,02:09
1,0.693311,0.693104,0.504545,02:08
2,0.692868,0.693206,0.504545,02:08
3,0.693315,0.693131,0.504545,02:08
4,0.693796,0.693031,0.504545,02:08
5,0.692727,0.693,0.504545,02:08


In [24]:
learn_crit.save('critic-pre2-resnet50')

Path('models/critic-pre2-resnet50.pth')

The GAN

In [14]:
dls_crit = get_crit_dls(fnames, bs=1, size=64)

In [15]:
learn_crit = create_crit_learner(dls_crit, metrics=None).load('critic-pre2-resnet50')
learn_crit.to_fp16()

<fastai.learner.Learner at 0x7fc1d0f19be0>

In [16]:
learn_gen = create_gen_learner().load('gen-pre2-resnet50')
learn_gen.to_fp16()

<fastai.learner.Learner at 0x7fc1d0b5b4e0>

In [17]:
class GANDiscriminativeLR(Callback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    def __init__(self, mult_lr=5.): self.mult_lr = mult_lr

    def before_batch(self):
        "Multiply the current lr by `mult_lr` on critic training batches."
        if not self.learn.gan_trainer.gen_mode and self.training:
            # Read the lr from the learner this callback is attached to (`self.learn`),
            # not from a notebook-global `learn` that may be missing or a different object.
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)

    # Older fastai v2 releases name this event `begin_batch`, newer ones `before_batch`.
    # Exposing both names keeps the multiply step paired with the divide in `after_batch`
    # on either version; a given release only dispatches one of the two names.
    begin_batch = before_batch

    def after_batch(self):
        "Put the LR back to its value if necessary."
        if not self.learn.gan_trainer.gen_mode:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)

In [18]:
switcher = AdaptiveGANSwitcher(critic_thresh=.65)

In [19]:
# Assemble the GAN from the two pre-trained learners, using the adaptive `switcher`
# and the critic LR-multiplier callback, then convert the combined learner with `to_fp16`.
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher,
                                 opt_func=partial(Adam, mom=0.), cbs=GANDiscriminativeLR(mult_lr=5.))
learn.to_fp16()

<fastai.vision.gan.GANLearner at 0x7fc1d02db550>

In [20]:
lr = 1e-4

In [None]:
# NOTE(review): the recorded output for this cell contains only the table header and no
# epoch rows, so no completed GAN training run is captured in this notebook.
learn.fit(10, lr, wd=wd)

  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")


epoch,train_loss,valid_loss,gen_loss,crit_loss,time


In [None]:
learn.save('gan-resnet50')