In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [28]:
from fastai import *
from fastai.vision import *
torch.backends.cudnn.benchmark=True
import time

from adamw import AdamW
from scheduler import Scheduler, LRScheduler
from models import *
from loss import TransferLoss
from data import ContentStyleLoader, InputDataset, SimpleDataBunch
from dist import DDP, sum_tensor, reduce_tensor, env_world_size, env_rank

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

import argparse

In [3]:
# Root directory for datasets and checkpoints, relative to the working directory.
PATH = Path('..') / '..' / 'data'

# Alternative cluster checkpoint configuration (disabled):
# MODEL_PATH = Path('/ncluster/models')
# MODEL_PATH.mkdir(parents=True, exist_ok=True)
# save_tag = 'resnet_fastai'

In [4]:
# Directory for model files: the `models` folder under the data root.
MODEL_PATH = PATH.joinpath('models')

In [5]:
# Parsing
def get_parser():
    """Build (but do not run) the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser with the worker, epoch, momentum, logging and
        distributed-training options. The positional dataset directory is not
        registered; data paths are configured in the notebook instead.
    """
    option_specs = [
        (('--phases',), dict(type=str, help='Learning rate schedule')),
        (('-j', '--workers'), dict(default=8, type=int, help='number of data loading workers (default: 8)')),
        (('--start-epoch',), dict(default=0, type=int, help='manual epoch number (useful on restarts)')),
        (('--momentum',), dict(default=0.9, type=float, help='momentum')),
        (('--print-freq', '-p'), dict(default=5, type=int, help='print every')),
        (('--dist-url',), dict(default='env://', type=str, help='url used to set up distributed training')),
        (('--dist-backend',), dict(default='nccl', type=str, help='distributed backend')),
        (('--local_rank',), dict(default=0, type=int, help='Used for multi-process training')),
    ]
    cli = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli

# Inside a Jupyter kernel, sys.argv holds the kernel's own launch flags, so parse
# an empty list and use the defaults. When this code runs as an exported script
# (e.g. under a distributed launcher that passes --local_rank), parse the real
# command line; otherwise every process would see local_rank == 0 and bind GPU 0.
import sys
args = get_parser().parse_args([] if 'ipykernel' in sys.modules else None)

In [6]:
# More than one process in the job means torch.distributed training.
is_distributed = env_world_size() > 1

# Every process except local rank 0 redirects stdout to /dev/null, so only one
# process per node prints. The handle is deliberately left open.
if args.local_rank > 0:
    sys.stdout = f = open('/dev/null', 'w')

if is_distributed:
    print('Distributed initializing process group')
    # Bind this process to its own local GPU before creating the process group.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=env_world_size(),
    )
    assert env_world_size() == dist.get_world_size()
    print("Distributed: success (%d/%d)" % (args.local_rank, dist.get_world_size()))


### Data

In [7]:
# Image side length and batch size per data loader.
# Earlier settings tried: 96 px / bs 36 and 128 px / bs 32.
size = 256
bs = 2

# Normalise / denormalise transforms built from fastai's ImageNet statistics.
data_norm, data_denorm = normalize_funcs(*imagenet_stats)

In [8]:
# Alternative content source (disabled):
# IMAGENET_PATH = PATH/'imagenet-sz/320/train'
# imagenet_files = get_files(IMAGENET_PATH, recurse=True)

# Content images: every file found recursively under coco/resize.
COCO_PATH = PATH/'coco/resize'
coco_files = get_files(COCO_PATH, recurse=True)
cont_ds = InputDataset(coco_files)

# Content Data: non-random crop/pad to `size` plus a random horizontal flip.
cont_tds = DatasetTfm(cont_ds, tfms=[crop_pad(size=size, is_random=False), flip_lr(p=0.5)], tfm_y=False, size=size, do_crop=True)
# Under distributed training each process reads its own shard. The sampler then
# controls ordering, so the loader itself must not shuffle.
data_sampler = DistributedSampler(cont_tds, num_replicas=env_world_size(), rank=env_rank()) if is_distributed else None
# The worker count comes from the --workers/-j option (default 8), so that CLI flag
# actually takes effect instead of being shadowed by a hard-coded value.
cont_dl = DeviceDataLoader.create(cont_tds, tfms=data_norm, num_workers=args.workers,
                                  bs=bs, shuffle=(data_sampler is None), sampler=data_sampler)


In [9]:
# Style Data

# Style images: every file found recursively under the DTD texture dataset.
STYLE_PATH_DTD = PATH/'style/dtd/images'
dtd_files = get_files(STYLE_PATH_DTD, recurse=True)

# Alternative style source (disabled):
# STYLE_PATH_PBN = PATH/'style/pbn/train'
# pbn_files = get_files(STYLE_PATH_PBN, recurse=True)

style_ds = InputDataset(dtd_files)
style_tds = DatasetTfm(style_ds, tfms=[crop_pad(size=size, is_random=False), flip_lr(p=0.5)], tfm_y=False, size=size, do_crop=True)
# One style image per batch (bs=1). The training loop tiles it across the content batch.
# The worker count honours the --workers/-j option, like the content loader.
style_dl = DeviceDataLoader.create(style_tds, tfms=data_norm, num_workers=args.workers, bs=1, shuffle=True)

In [10]:
# Combine the content and style loaders into a single training loader.
# ContentStyleLoader lives in data.py; presumably it pairs each content batch
# with style images (TODO confirm). The result is wrapped, together with
# MODEL_PATH, in the DataBunch-like object that Learner receives.
train_dl = ContentStyleLoader(cont_dl, style_dl)
data = SimpleDataBunch(train_dl, MODEL_PATH)

### Callback

In [25]:

class DistributedRecorder(Recorder):
    "Recorder that, under distributed training, records the loss reduced across all processes."
    def on_backward_begin(self, smooth_loss:Tensor, **kwargs:Any)->None:
        """Reduce `smooth_loss` across processes (when distributed), then record it.

        `reduce_tensor` comes from the local `dist` module; presumably it
        all-reduces and averages over the world size (TODO confirm).
        """
        if is_distributed:
            metrics = torch.tensor([float(smooth_loss)], dtype=torch.float32).cuda()
            # Hand the base class a 0-d CPU tensor, the same kind of value it gets in
            # the non-distributed path.
            # NOTE(review): fastai v1's Recorder formats this value as f'{v:.4f}'
            # for the progress bar. A 1-element numpy array (the previous
            # `.cpu().numpy()` result) does not support that format spec.
            smooth_loss = reduce_tensor(metrics)[0].cpu()
        super().on_backward_begin(smooth_loss, **kwargs)
            

In [12]:
@dataclass
class WeightScheduler(Callback):
    "Anneal the content and style weights of a `TransferLoss` phase by phase during training."
    learn:Learner
    loss_func:TransferLoss
    # Each phase is a tuple (n_epochs, start_weight, end_weight).
    cont_phases:Collection[Tuple]
    style_phases:Collection[Tuple]

    def steps(self, phases):
        """Build one linearly-annealed `Stepper` per phase.

        A phase `(ep, start, end)` runs for `ep * len(train_dl)` batches, moving
        from `start` to `end` with `annealing_linear`.
        """
        n_batch = len(self.learn.data.train_dl)
        return [Stepper((start,end),ep*n_batch,annealing_linear) for ep,start,end in phases]

    def on_train_begin(self, n_epochs:int, **kwargs:Any)->None:
        "Build both schedules and make the first phase of each the current one."
        # Reversed so that `pop()` hands out phases in their declared order.
        self.style_scheds = list(reversed(self.steps(self.style_phases)))
        self.cont_scheds = list(reversed(self.steps(self.cont_phases)))

        self.cur_style = self.style_scheds.pop()
        self.cur_cont = self.cont_scheds.pop()

    def on_batch_end(self, train, **kwargs:Any)->None:
        "After each training batch, advance both schedules and push the new weights into the loss."
        if not train: return
        # The content schedule drives the content weight and the style schedule
        # drives the style weight. A finished final phase is no longer stepped, so
        # its last weight is kept instead of being extrapolated past `end`.
        if not self.cur_cont.is_done: self.loss_func.ct_wgt = self.cur_cont.step()
        if not self.cur_style.is_done: self.loss_func.st_wgt = self.cur_style.step()

        # Move to the next phase only if one remains; popping an empty list would
        # raise IndexError once the last phase completes.
        if self.cur_style.is_done and self.style_scheds: self.cur_style = self.style_scheds.pop()
        if self.cur_cont.is_done and self.cont_scheds: self.cur_cont = self.cont_scheds.pop()

In [13]:
# Learning-rate multiplier: base learning rates are scaled by the number of
# processes reported by env_world_size().
lr_mult = env_world_size()

## Create models

In [15]:
# Create models: the StyleTransformer and the ResNet-based StylePredict network,
# combined into a single module that is then moved to the current GPU.
mt = StyleTransformer()
ms = StylePredict.create_resnet()  # alternative: StylePredict.create_inception()
m_com = CombinedModel(mt, ms)
m_com = m_com.cuda()

In [16]:
# Under multi-process training, wrap the model in DDP (from the local `dist`
# module), pinned to this process's local GPU; otherwise leave it unchanged.
m_com = DDP(m_com, device_ids=[args.local_rank], output_device=args.local_rank) if is_distributed else m_com

In [17]:
# Optimizer factory for the fastai Learner: AdamW with fixed betas and weight decay.
opt_func = partial(AdamW, weight_decay=1e-3, betas=(0.9, 0.999))

In [18]:
# Base loss weights: style, content, and tva (presumably total variation).
st_wgt, ct_wgt, tva_wgt = 2.5e9, 5e2, 1e-6
# Relative style weight per VGG feature block (original note: blocks 2,3,4,5).
st_block_wgts = [1, 80, 200, 5]
# Index into the VGG activation list used for the content loss (original note: "1=3").
c_block = 1
lr_mult = env_world_size()

In [None]:
# Loss-weight schedules consumed by WeightScheduler; each phase is (n_epochs, start, end).
epochs = 10
style_phases = [
    (2, 1e2, st_wgt * 3),        # ramp the style weight from 1e2 up to 3x st_wgt
    (epochs, st_wgt, st_wgt),    # then hold at st_wgt
]
cont_phases = [
    (2, ct_wgt, ct_wgt),         # hold at ct_wgt
    (2, ct_wgt, ct_wgt * 3),     # ramp up to 3x ct_wgt
    (epochs, ct_wgt, ct_wgt),    # back to ct_wgt and hold
]

In [19]:
# VGG network whose intermediate activations feed the content/style losses.
# Its constructor appears to print the selected layer ids (see the cell output).
m_vgg = VGGActivations()
m_vgg = m_vgg.cuda()

Layer ids:  [12, 22, 32, 42]


In [20]:
# Same value as computed in earlier cells; recomputing it keeps this cell self-contained.
lr_mult = env_world_size()

In [21]:
# Style-transfer loss. The arguments are positional; their meaning is defined by
# TransferLoss in loss.py (not visible here). WeightScheduler later overwrites
# loss_func.ct_wgt and loss_func.st_wgt during training.
loss_func = TransferLoss(m_vgg, ct_wgt, st_wgt, st_block_wgts, tva_wgt, data_norm, c_block)

In [22]:
# fastai Learner over the combined model, using the custom loss and AdamW factory.
learner = Learner(data, m_com, loss_func=loss_func, opt_func=opt_func)

In [26]:
# Everything except `learn` is bound now; presumably fastai supplies the learner
# when it instantiates each entry of callback_fns.
w_sched = partial(WeightScheduler, style_phases=style_phases, cont_phases=cont_phases, loss_func=loss_func)
# This assignment replaces the whole callback_fns list with the distributed-aware
# recorder plus the loss-weight scheduler.
learner.callback_fns = [DistributedRecorder, w_sched]

In [27]:
# One-cycle fit for a single epoch; the peak LR (fastai's max_lr) is 1e-5 scaled by lr_mult.
learner.fit_one_cycle(1, 1e-5*lr_mult)

epoch,train_loss,valid_loss
,,


Restarting style


NameError: name 'style_dl' is not defined

### Set params

In [None]:
# Settings for the hand-written training loop (this rebinds `epochs`).
epochs, log_interval = 1, 50
optimizer = AdamW(
    m_com.parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    weight_decay=1e-3,
)

In [None]:
lr_mult = env_world_size()
# Piecewise LR schedule, every value scaled by lr_mult. Each entry gives an epoch
# range and the (start, end) LR over it; how LRScheduler interpolates inside a
# range is defined in scheduler.py (not visible here).
# NOTE(review): when epochs <= 2 the last range (2, epochs) is empty or inverted; confirm intent.
scheduler = LRScheduler(
    optimizer,
    [
        {'ep': (0, 1), 'lr': (1e-5 * lr_mult, 5e-4 * lr_mult)},
        {'ep': (1, 2), 'lr': (5e-4 * lr_mult, 1e-5 * lr_mult)},
        {'ep': (2, epochs), 'lr': (1e-5 * lr_mult, 1e-7 * lr_mult)},
    ],
)

In [None]:
# Loss weights for the manual training loop.
st_wgt, ct_wgt, tva_wgt = 2.5e9, 5e2, 1e-6
# Relative style weight per VGG feature block (original note: blocks 2,3,4,5).
style_block_wgts = [1, 80, 200, 5]
# Index into the VGG activation list used for the content loss (original note: "1=3").
c_block = 1
# Absolute per-block style weights; blocks with weight 0 are skipped by the loop.
style_wgts = [block_wgt * st_wgt for block_wgt in style_block_wgts]

In [None]:
# Checkpoint to resume from: MODEL_PATH/<save_tag>.pth. It is defined here so this
# cell does not fail with NameError on a fresh kernel.
save_tag = 'resnet_fastai'
# Older checkpoint, loaded non-strictly:
# m_com.load_state_dict(torch.load(MODEL_PATH/f'model_combined_4_imagenet_256.pth'), strict=False)
# map_location='cpu' stops every process from materialising the tensors on the GPU
# they were saved from; load_state_dict then copies them onto m_com's own device.
# NOTE(review): if m_com was wrapped in DDP above, its keys carry a 'module.'
# prefix, and strict=True requires the checkpoint to have been saved the same way.
m_com.load_state_dict(torch.load(MODEL_PATH/f'{save_tag}.pth', map_location='cpu'), strict=True)

In [None]:

start = time.time()
m_com.train()
style_image_count = 0
# Style images are consumed one per step from an iterator that restarts whenever
# the style loader is exhausted. Create it up front rather than relying on a
# NameError inside the loop to initialise it.
it_style = iter(style_dl)
mom = 0.9  # decay of the exponential moving averages used only for logging
for e in range(epochs):
    # DistributedSampler derives its shuffle order from the epoch number. Without
    # set_epoch every epoch would see the same order.
    if data_sampler is not None:
        data_sampler.set_epoch(e)
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_tva_loss = 0.
    count = 0
    batch_tot = len(train_dl)
    for batch_id, (x_con,_) in enumerate(train_dl):
        scheduler.update_lr(e, batch_id, batch_tot)

        # Next style image. Only exhaustion of the style loader triggers a restart;
        # any other error (CUDA, KeyboardInterrupt, ...) propagates.
        try:
            x_style,_ = next(it_style)
            style_image_count += 1
        except StopIteration:
            it_style = iter(style_dl)
            x_style,_ = next(it_style)
            print('Restarting style')
            style_image_count = 1

        n_batch = x_con.size(0)
        # Loss targets: VGG features of the style image and of the content batch.
        # The style image is tiled to the real batch size, which can be smaller
        # than `bs` on the last batch of an epoch.
        with torch.no_grad():
            style_batch = x_style.repeat(n_batch,1,1,1)
            style_feat = [s.clone() for s in m_vgg(style_batch)]
            targ_feat = m_vgg(x_con)[c_block].clone()

        count += n_batch
        optimizer.zero_grad()

        # Stylise, normalise the output the same way as the inputs, and extract its VGG features.
        out = m_com(x_con, x_style)
        out,_ = data_norm((out,None))
        inp_feat = m_vgg(out)

        closs = [ct_loss(inp_feat[c_block],targ_feat) * ct_wgt]
        # Style blocks with a zero weight are skipped entirely.
        sloss = [gram_loss(inp,targ)*wgt for inp,targ,wgt in zip(inp_feat, style_feat, style_wgts) if wgt > 0]
        tvaloss = tva_loss(out) * tva_wgt

        total_loss = sum(closs + sloss + [tvaloss])
        total_loss.backward()
        # nn.utils.clip_grad_norm_(m_com.m_tran.parameters(), 10)
        # nn.utils.clip_grad_norm_(m_com.m_style.parameters(), 80)
        optimizer.step()

        # Smoothed losses, used for logging only.
        agg_content_loss = agg_content_loss*mom + sum(closs).detach().data*(1-mom)
        agg_style_loss = agg_style_loss*mom + sum(sloss).detach().data*(1-mom)
        agg_tva_loss = agg_tva_loss*mom + tvaloss.detach().data*(1-mom)
        agg_total_loss = (agg_content_loss + agg_style_loss + agg_tva_loss)

        if (batch_id + 1) % log_interval == 0:
            time_elapsed = (time.time() - start)/60
            mesg = (f"MIN:{time_elapsed:.2f}\tEP[{e+1}]\tB[{batch_id+1:4}/{batch_tot}]\t"
                    f"CON:{agg_content_loss:.3f}\tSTYL:{agg_style_loss:.2f}\t"
                    f"TVA:{agg_tva_loss:.2f}\tTOT:{agg_total_loss:.2f}\t"
#                     f"S/CT:{style_image_count:3}/{count:3}"
                   )
            print(mesg)

In [None]:
def eval_imgs(x_con, x_style, idx=0):
    """Plot content image `idx`, the style image and the stylised result side by side.

    Args:
        x_con: normalised content batch, as yielded by `train_dl`.
        x_style: normalised style batch; only element 0 is shown (the style
            loader uses bs=1).
        idx: index of the content image within `x_con` to display.

    The forward pass runs with `m_com` in eval mode, and the previous train/eval
    mode is restored afterwards. This stops any BatchNorm layers from updating
    their running statistics during a visualisation call (the training loop
    leaves the model in train mode).
    """
    was_training = m_com.training
    m_com.eval()
    try:
        with torch.no_grad():
            out = m_com(x_con, x_style)
    finally:
        m_com.train(was_training)
    fig, axs = plt.subplots(1,3,figsize=(12,4))
    for ax, title in zip(axs, ('Content', 'Style', 'Transfer')):
        ax.set_title(title)
    Image(data_denorm(x_con[idx].cpu())).show(axs[0])
    Image(data_denorm(x_style[0].cpu())).show(axs[1])
    # The transformer output is shown as-is, without denormalisation.
    Image(out[idx].detach().cpu()).show(axs[2])
    

In [None]:
# Take one content batch to visualise; `idx` selects the image within it.
x_con, _ = next(iter(train_dl))
idx = 0

In [None]:
# Draw a fresh style image and show the transfer result for the chosen content image.
x_style, _ = next(iter(style_dl))
eval_imgs(x_con, x_style, idx=idx)

In [None]:
# Another style image applied to the same content image.
x_style, _ = next(iter(style_dl))
eval_imgs(x_con, x_style, idx=idx)