In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [6]:
from fastai import *
from fastai.vision import *
torch.backends.cudnn.benchmark=True
import time

from adamw import AdamW
from scheduler import Scheduler, LRScheduler
from models import *
from loss import TransferLoss
from data import ContentStyleLoader, InputDataset, SimpleDataBunch
from dist import DDP, sum_tensor, reduce_tensor, env_world_size, env_rank

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

import argparse

In [7]:
# Data root: prefer ../../data when it exists, otherwise fall back to ../data.
PATH = Path('../../data') if Path('../../data').exists() else Path('../data')

# Optional checkpoint settings (currently disabled):
# PATH = Path('/ncluster/models/resnet_test.pth')
# MODEL_PATH = Path('/ncluster/models')
# MODEL_PATH.mkdir(exist_ok=True)
# save_tag = 'resnet_test'

In [8]:
# Parsing
def get_parser():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
#     parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--phases', type=str, help='Learning rate schedule')
    parser.add_argument('-j', '--workers', default=8, type=int, help='number of data loading workers (default: 8)')
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--print-freq', '-p', default=5, type=int, help='print every')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=0, type=int, help='Used for multi-process training')
    return parser

args = get_parser().parse_args([])

In [9]:
# Explicit imports: `sys` was previously only reachable through `from fastai import *`.
import os
import sys

is_distributed = env_world_size() > 1
if args.local_rank > 0:
    # Silence every process except local rank 0 so output is not duplicated.
    # The handle is intentionally kept open for the lifetime of the process.
    f = open(os.devnull, 'w')
    sys.stdout = f

if is_distributed:
    print('Distributed initializing process group')
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=env_world_size())
    assert env_world_size() == dist.get_world_size()
    # Report the global rank: local_rank repeats across nodes and would read as "0/N" on every node.
    print("Distributed: success (%d/%d)"%(dist.get_rank(), dist.get_world_size()))


## Data

In [10]:
# Training resolution and per-process batch size (earlier runs used 96/36 and 128/32).
size, bs = 256, 2

# Normalize / denormalize transforms built from the ImageNet channel statistics.
data_norm, data_denorm = normalize_funcs(*imagenet_stats)

In [11]:
# IMAGENET_PATH = PATH/'imagenet-sz/320/train'
# imagenet_files = get_files(IMAGENET_PATH, recurse=True)

COCO_PATH = PATH/'coco/resize'
coco_files = get_files(COCO_PATH, recurse=True)
cont_ds = InputDataset(coco_files)

# Content Data: crop_pad with is_random=False, then flip_lr(p=0.5).
cont_tds = DatasetTfm(cont_ds, tfms=[crop_pad(size=size, is_random=False), flip_lr(p=0.5)], tfm_y=False, size=size, do_crop=True)
# When distributed, shard the content set across processes; the sampler then owns shuffling,
# so the loader's own shuffle is turned off.
data_sampler = DistributedSampler(cont_tds, num_replicas=env_world_size(), rank=env_rank()) if is_distributed else None
# Worker count comes from the parsed --workers option (default 8) instead of a hardcoded value.
cont_dl = DeviceDataLoader.create(cont_tds, tfms=data_norm, num_workers=args.workers,
                                   bs=bs, shuffle=(data_sampler is None), sampler=data_sampler)


In [12]:
# Style Data

STYLE_PATH_DTD = PATH/'style/dtd/images'
dtd_files = get_files(STYLE_PATH_DTD, recurse=True)

# STYLE_PATH_PBN = PATH/'style/pbn/train'
# pbn_files = get_files(STYLE_PATH_PBN, recurse=True)

style_ds = InputDataset(dtd_files)
# Same crop_pad(is_random=False) + flip_lr(p=0.5) pipeline as the content data.
style_tds = DatasetTfm(style_ds, tfms=[crop_pad(size=size, is_random=False), flip_lr(p=0.5)], tfm_y=False, size=size, do_crop=True)
# One shuffled style image per batch; not sharded across processes (no DistributedSampler here).
# Worker count comes from the parsed --workers option (default 8) instead of a hardcoded value.
style_dl = DeviceDataLoader.create(style_tds, tfms=data_norm, num_workers=args.workers, bs=1, shuffle=True)

In [13]:
# Combined loader: the training loop unpacks each item as (x_con, x_style).
# NOTE(review): exact effect of repeat_xy=False is defined in data.ContentStyleLoader — confirm.
train_dl = ContentStyleLoader(cont_dl, style_dl, repeat_xy=False)

## Create models

In [14]:
# Build the style-transfer net and the style-prediction net (ResNet backbone;
# StylePredict.create_inception() is the alternative), then combine them on the GPU.
mt, ms = StyleTransformer(), StylePredict.create_resnet()
m_com = CombinedModel(mt, ms).cuda()

# Under torch.distributed, wrap the combined model and pin it to this process's GPU.
if is_distributed:
    m_com = DDP(m_com, device_ids=[args.local_rank], output_device=args.local_rank)

## Set training parameters

In [15]:
# Training length and logging cadence (a log line every `log_interval` batches).
epochs, log_interval = 3, 50
# lr here is only a starting value: the training loop calls scheduler.update_lr every batch.
optimizer = AdamW(
    m_com.parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    weight_decay=1e-3,
)

In [16]:
# Scale every learning rate linearly with the number of processes.
lr_mult = env_world_size()
# Phases as (epoch range, unscaled (start_lr, end_lr)): ramp up, ramp down, then anneal.
scheduler = LRScheduler(optimizer, [
    {'ep': ep_range, 'lr': (lr_lo*lr_mult, lr_hi*lr_mult)}
    for ep_range, (lr_lo, lr_hi) in [((0, 1), (1e-5, 5e-4)),
                                     ((1, 2), (5e-4, 1e-5)),
                                     ((2, epochs), (1e-5, 1e-7))]
])

In [17]:
# Loss weights passed to TransferLoss.
st_wgt = 2.5e9   # style weight
ct_wgt = 5e2     # content weight
tva_wgt = 1e-6   # total-variation weight
# Per-block style weights; the original note "2,3,4,5" presumably names the VGG blocks used — confirm.
st_block_wgts = [1,80,200,5]
# Content block index; original note "1=3" — meaning not visible here, confirm against loss.TransferLoss.
c_block = 1
# (lr_mult is already set by the scheduler cell; the redundant re-assignment was removed.)

In [18]:
# VGG activation extractor on the GPU; passed to TransferLoss as its feature network.
m_vgg = VGGActivations().cuda()

Layer ids:  [12, 22, 32, 42]


In [19]:
# Perceptual loss. Positional order presumably (feature net, content wgt, style wgt, per-block style
# wgts, TV wgt, norm fn, content block) — confirm against loss.TransferLoss.
# sum_loss=False: the training loop concatenates and sums the returned content/style terms itself.
m_loss = TransferLoss(m_vgg, ct_wgt, st_wgt, st_block_wgts, tva_wgt, data_norm, c_block, sum_loss=False)

In [20]:
# m_com.load_state_dict(torch.load(MODEL_PATH/f'model_combined_4_imagenet_256.pth'), strict=False)
# m_com.load_state_dict(torch.load(MODEL_PATH/f'{save_tag}.pth'), strict=True)

In [21]:

# Train the combined transfer/style model for `epochs` epochs.
start = time.time()
m_com.train()
# NOTE(review): style_image_count is never incremented, so the "S/" log field is always 0 — confirm intent.
style_image_count = 0
mom = 0.9             # smoothing factor of the running (EMA) loss averages
save_interval = 1000  # batches between checkpoint writes (local rank 0 only)
# Gradient clipping targets sub-networks of the underlying CombinedModel, not the DDP wrapper.
m_clip = m_com.module if is_distributed else m_com
# Treat an args namespace without a `--save` option as "do not save" instead of raising AttributeError.
save_arg = getattr(args, 'save', None)
for e in range(epochs):
    if data_sampler is not None:
        # DistributedSampler derives its shuffle from the epoch; without set_epoch every epoch
        # would replay the same sample order.
        data_sampler.set_epoch(e)
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_tva_loss = 0.
    count = 0  # samples seen by this process in the current epoch
    batch_tot = len(train_dl)
    for batch_id, (x_con, x_style) in enumerate(train_dl):
        scheduler.update_lr(e, batch_id, batch_tot)

        n_batch = x_con.size(0)
        count += n_batch
        optimizer.zero_grad()

        out = m_com(x_con, x_style)
        # Apply data_norm to the generated image, as the loaders do for the inputs, before the loss.
        out, _ = data_norm((out, None))

        # closs/sloss are concatenated as lists (m_loss was built with sum_loss=False).
        closs, sloss, tvaloss = m_loss(out, x_con, x_style)
        total_loss = sum(closs + sloss + [tvaloss])

        total_loss.backward()
        nn.utils.clip_grad_norm_(m_clip.m_tran.parameters(), 10)
        nn.utils.clip_grad_norm_(m_clip.m_style.parameters(), 100)
        optimizer.step()

        # Per-process exponential moving averages of each loss component.
        agg_content_loss = agg_content_loss*mom + sum(closs).detach()*(1-mom)
        agg_style_loss = agg_style_loss*mom + sum(sloss).detach()*(1-mom)
        agg_tva_loss = agg_tva_loss*mom + tvaloss.detach()*(1-mom)
        agg_total_loss = agg_content_loss + agg_style_loss + agg_tva_loss

        if (batch_id + 1) % log_interval == 0:
            log_con, log_sty, log_tva, log_tot = agg_content_loss, agg_style_loss, agg_tva_loss, agg_total_loss
            if is_distributed:
                # Reduce only when logging, not every batch. Every rank evaluates the same batch_id
                # condition, so the collective stays matched. Assuming reduce_tensor averages across
                # ranks, the mean of per-rank EMAs equals the EMA of per-batch means (EMA is linear).
                metrics = torch.cat([v.float().reshape(-1) for v in (log_con, log_sty, log_tva, log_tot)]).cuda()
                log_con, log_sty, log_tva, log_tot = reduce_tensor(metrics).cpu().numpy()
            time_elapsed = (time.time() - start)/60
            mesg = (f"MIN:{time_elapsed:.2f}\tEP[{e+1}]\tB[{batch_id+1:4}/{batch_tot}]\t"
                    f"CON:{log_con:.3f}\tSTYL:{log_sty:.2f}\t"
                    f"TVA:{log_tva:.2f}\tTOT:{log_tot:.2f}\t"
                    f"S/CT:{style_image_count:3}/{count:3}"
                   )
            print(mesg)

        if (args.local_rank == 0) and save_arg and (batch_id + 1) % save_interval == 0:
            print('Saving model: ', save_arg)
            save_path = Path(save_arg).expanduser()
            save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(m_com.state_dict(), save_path)

            # Build "<stem>_<epoch><suffix>" directly. Chaining .with_name(...).with_suffix(...) replaced
            # the last dotted part of stems containing dots (model.v1.pth -> model.pth, clobbering the main save).
            ep_save = save_path.with_name(f'{save_path.stem}_{e}{save_path.suffix}')

            print('Saving epoch checkpoint: ', ep_save)
            torch.save(m_com.state_dict(), ep_save)

Changing LR from 1e-05 to 1.0008285004142502e-05
MIN:0.16	EP[1]	B[  50/59143]	CON:10.052	STYL:103.82	TVA:0.07	TOT:113.94	S/CT:  0/100
MIN:0.22	EP[1]	B[ 100/59143]	CON:10.165	STYL:117.69	TVA:0.07	TOT:127.92	S/CT:  0/200
MIN:0.28	EP[1]	B[ 150/59143]	CON:9.723	STYL:86.59	TVA:0.06	TOT:96.38	S/CT:  0/300
MIN:0.34	EP[1]	B[ 200/59143]	CON:10.130	STYL:87.29	TVA:0.06	TOT:97.48	S/CT:  0/400
MIN:0.40	EP[1]	B[ 250/59143]	CON:9.596	STYL:93.95	TVA:0.06	TOT:103.61	S/CT:  0/500
MIN:0.46	EP[1]	B[ 300/59143]	CON:9.983	STYL:69.37	TVA:0.07	TOT:79.42	S/CT:  0/600
MIN:0.53	EP[1]	B[ 350/59143]	CON:9.978	STYL:70.44	TVA:0.07	TOT:80.48	S/CT:  0/700
MIN:0.59	EP[1]	B[ 400/59143]	CON:10.310	STYL:96.92	TVA:0.07	TOT:107.30	S/CT:  0/800
MIN:0.65	EP[1]	B[ 450/59143]	CON:10.513	STYL:74.34	TVA:0.08	TOT:84.93	S/CT:  0/900


Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/ubuntu/anaconda3/envs/fastai/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

In [None]:
def eval_imgs(x_con, x_style, model=m_com, idx=0):
    """Show a content image, a style image and the model's transfer output side by side.

    Args:
        x_con: content batch as produced by the (data_norm-normalized) loaders.
        x_style: style batch; only element 0 is displayed.
        model: network called as ``model(x_con, x_style)``. The default binds the
            ``m_com`` object that exists when this cell runs; reassigning ``m_com``
            later does not change the default.
        idx: index of the content image (and matching output) to display.
    """
    # Use the `model` argument; previously the body called the global m_com and ignored it.
    with torch.no_grad():
        out = model(x_con, x_style)
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    panels = [
        ('Content', data_denorm(x_con[idx].cpu())),
        ('Style', data_denorm(x_style[0].cpu())),
        # Output shown without data_denorm: training applies data_norm to the model output before
        # the loss, so it is presumably already in image space — confirm.
        ('Transfer', out[idx].cpu()),
    ]
    for ax, (title, img) in zip(axs, panels):
        Image(img).show(ax)
        ax.set_title(title)
    

In [None]:
# Take the first batch of a fresh train_dl iterator as the content images for visual checks
# (the paired style batch is discarded).
x_con,_ = next(iter(train_dl))
idx=0  # which image in the content batch to display

In [None]:
# Draw a style image (style_dl shuffles) and render content / style / transfer.
x_style,_ = next(iter(style_dl))
eval_imgs(x_con, x_style, idx=idx)

In [None]:
# Same content batch with another shuffled style image, for comparison.
x_style,_ = next(iter(style_dl))
eval_imgs(x_con, x_style, idx=idx)