In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai import *
from fastai.vision import *
torch.backends.cudnn.benchmark=True
import time

from adamw import AdamW
from scheduler import Scheduler
from models import *

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

import argparse

In [3]:
PATH = Path('../../data')
MODEL_PATH = PATH/'models'
MODEL_PATH.mkdir(exist_ok=True)
save_tag = 'imagenet_magenta'

In [4]:
# Parsing
def get_parser():
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
#     parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--phases', type=str, help='Learning rate schedule')
    parser.add_argument('-j', '--workers', default=8, type=int, help='number of data loading workers (default: 8)')
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--print-freq', '-p', default=5, type=int, help='print every')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=0, type=int, help='Used for multi-process training')
    return parser

# Distributed
def reduce_tensor(tensor): return sum_tensor(tensor)/env_world_size()
def sum_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    return rt
def env_world_size(): return int(os.environ.get('WORLD_SIZE', 1))
def env_rank(): return int(os.environ.get('RANK',0))

args = get_parser().parse_args([])

In [5]:
is_distributed = env_world_size() > 1
if args.local_rank > 0:
    f = open('/dev/null', 'w')
    sys.stdout = f
    
if is_distributed:
    print('Distributed initializing process group')
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=env_world_size())
    assert(env_world_size() == dist.get_world_size())
    print("Distributed: success (%d/%d)"%(args.local_rank, dist.get_world_size()))


In [6]:
fp16 = True
mean,std = imagenet_stats
if fp16: 
    mean = tensor(mean).half()
    std = tensor(std).half()

In [10]:
# IMAGENET_PATH = PATH/'imagenet-sz/160/train'
# train_ds = ImageClassificationDataset.from_folder(IMAGENET_PATH)
COCO_PATH = PATH/'coco/resize'
train_ds = ImageClassificationDataset.from_folder(COCO_PATH)

In [12]:
# size,bs = 96,36
size,bs = 128,32
# size,bs = 256,20

# Content Data
data_norm,data_denorm = normalize_funcs(mean, std)
train_tds = DatasetTfm(train_ds, tfms=[crop_pad(size=size, is_random=False)], tfm_y=False, size=size, do_crop=True)

data_sampler = DistributedSampler(train_tds, num_replicas=env_world_size(), rank=env_rank()) if is_distributed else None
train_dl = DeviceDataLoader.create(train_tds, tfms=data_norm, num_workers=8, bs=bs, shuffle=(data_sampler is None), sampler=data_sampler)
train_dl = DeviceDataLoader.create(train_tds, tfms=[to_half, data_norm], num_workers=8, bs=bs, shuffle=(data_sampler is None), sampler=data_sampler)

# Style Data

# STYLE_PATH = PATH/'style/dtd/images'
STYLE_PATH = PATH/'style/dtd/subset'
style_ds = ImageClassificationDataset.from_folder(STYLE_PATH)

# STYLE_PATH = PATH/'style/pbn/train'
# style_ds = ImageClassificationDataset.from_single_folder(STYLE_PATH, ['train'])

style_tds = DatasetTfm(style_ds, tfms=[crop_pad(size=size, is_random=False)], tfm_y=False, size=size, do_crop=True)
style_dl = DeviceDataLoader.create(style_tds, tfms=data_norm, num_workers=8, bs=1, shuffle=True)
style_dl = DeviceDataLoader.create(style_tds, tfms=[to_half, data_norm], num_workers=8, bs=1, shuffle=True)

In [10]:
# if fp16:
#     train_dl.add_tfm(to_half)
#     style_dl.add_tfm(to_half)

### Loss

In [13]:
# Loss Functions
# losses
def ct_loss(input, target): return F.mse_loss(input,target)

def gram(input):
        b,c,h,w = input.size()
        x = input.view(b, c, -1)
        return torch.bmm(x, x.transpose(1,2))/(c*h*w)

def gram_loss(input, target): return F.mse_loss(gram(input), gram(target))

def tva_loss(y):
    w_var = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:]))
    h_var = torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
    return w_var + h_var


### Distributed

In [14]:
from torch.utils.data.distributed import DistributedSampler

def reduce_tensor(tensor): return sum_tensor(tensor)/env_world_size()
def sum_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    return rt

def env_world_size(): return int(os.environ.get('WORLD_SIZE', 1))
def env_rank(): return int(os.environ.get('RANK',0))

In [15]:
# losses
def ct_loss(input, target): return F.mse_loss(input,target)

def gram(input):
        b,c,h,w = input.size()
        x = input.view(b, c, -1)
        return torch.bmm(x, x.transpose(1,2))/(c*h*w)

def gram_loss(input, target): return F.mse_loss(gram(input), gram(target))

def tva_loss(y):
    w_var = torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:]))
    h_var = torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
    return w_var + h_var

## Create models

In [16]:
# Create models
mt = StyleTransformer()
ms = StylePredict.create_inception()
m_com = CombinedModel(mt, ms).cuda()

In [17]:
m_vgg = VGGActivations().cuda()

if is_distributed: 
    m_com = DistributedDataParallel(m_com, device_ids=[args.local_rank], output_device=args.local_rank)
    m_vgg = DistributedDataParallel(m_vgg, device_ids=[args.local_rank], output_device=args.local_rank)
    
    

Layer ids:  [12, 22, 32, 42]


### FP16

In [18]:
from fp16_utils import FP16_Optimizer, network_to_half, BN_convert_float

### Set params

In [19]:
epochs = 3
log_interval = 50
optimizer = AdamW(m_com.parameters(), lr=1e-5, betas=(0.9,0.999), weight_decay=1e-3)

In [20]:
if fp16:
    m_com = BN_convert_float(m_com.half())
    m_vgg = BN_convert_float(m_vgg.half())
    optimizer = FP16_Optimizer(optimizer,
                               static_loss_scale=256,
                               dynamic_loss_scale=True);
    

FP16_Optimizer processing param group 0:
FP16_Optimizer received torch.cuda.HalfTensor with torch.Size([32, 3, 9, 9])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([32])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([32])
FP16_Optimizer received torch.cuda.HalfTensor with torch.Size([64, 32, 3, 3])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([64])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([64])
FP16_Optimizer received torch.cuda.HalfTensor with torch.Size([128, 64, 3, 3])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([128])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([128])
FP16_Optimizer received torch.cuda.HalfTensor with torch.Size([128, 128, 3, 3])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([128])
FP16_Optimizer received torch.cuda.FloatTensor with torch.Size([128])
FP16_Optimizer received torch.cuda.HalfTensor with torch.Size([128, 128, 3, 3])
FP16_Opti

In [21]:
lr_mult = env_world_size()
scheduler = Scheduler(optimizer, [{'ep': (0,1),      'lr': (1e-5*lr_mult,5e-4*lr_mult)}, 
                                  {'ep': (1,2),      'lr': (5e-4*lr_mult,1e-5*lr_mult)},
                                  {'ep': (2,epochs), 'lr': (1e-5*lr_mult,1e-7*lr_mult)}])

In [22]:
style_wgts = [i*1e9 for i in [5,50,5,.5]] # 2,3,4,5
c_block = 1 # 1=3
ct_wgt = 1e3
tva_wgt = 5e-6

In [23]:
# m_com.load_state_dict(torch.load(MODEL_PATH/f'model_combined_4_imagenet_256.pth'), strict=False)
# m_com.load_state_dict(torch.load(MODEL_PATH/f'model_combined_{save_tag}.pth'), strict=True)

In [28]:
x_con,_ = next(iter(train_dl))

In [26]:

start = time.time()
m_com.train()
style_image_count = 0
for e in range(epochs):
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_tva_loss = 0.
    count = 0
    batch_tot = len(train_dl)
    for batch_id, (x_con,_) in enumerate(train_dl):
        scheduler.update_lr(e, batch_id, batch_tot)
        
#         if (batch_id % (log_interval*3) == 0):
        try:
            x_style,_ = next(it_style)
            style_image_count += 1
        except:
            it_style = iter(style_dl)
            x_style,_ = next(it_style)
            print('Restarting style')
            style_image_count = 1
            
        with torch.no_grad(): 
            style_batch = x_style.repeat(bs,1,1,1)
            s_out = m_vgg(style_batch)
            style_feat = [s.clone() for s in s_out]
            
            targ_feat = m_vgg(x_con)[c_block].clone()
            
        n_batch = x_con.size(0)
        count += n_batch
        optimizer.zero_grad()
        
        out = m_com(x_con, x_style)
        out,_ = data_norm((out,None))
        inp_feat = m_vgg(out)
        
        closs = [ct_loss(inp_feat[c_block],targ_feat) * ct_wgt]
        sloss = [gram_loss(inp,targ)*wgt for inp,targ,wgt in zip(inp_feat, style_feat, style_wgts) if wgt > 0]
        tvaloss = tva_loss(out) * tva_wgt
        
        total_loss = closs + sloss + [tvaloss]
        total_loss = sum(total_loss)
    
        if fp16:
            optimizer.backward(total_loss)
        else:
            total_loss.backward()
            
#         nn.utils.clip_grad_norm_(m_com.m_tran.parameters(), 10)
#         nn.utils.clip_grad_norm_(m_com.m_style.parameters(), 80)
        optimizer.step()
    
        mom = 0.9
        agg_content_loss = agg_content_loss*mom + sum(closs).detach().data*(1-mom)
        agg_style_loss = agg_style_loss*mom + sum(sloss).detach().data*(1-mom)
        agg_tva_loss = agg_tva_loss*mom + tvaloss.detach().data*(1-mom)
        agg_total_loss = (agg_content_loss + agg_style_loss + agg_tva_loss)
        
        if (batch_id + 1) % log_interval == 0:
            time_elapsed = (time.time() - start)/60
            b_tot = len(train_dl)
            mesg = (f"MIN:{time_elapsed:.2f}\tEP[{e+1}]\tB[{batch_id+1:4}/{b_tot}]\t"
                    f"CON:{agg_content_loss:.3f}\tSTYL:{agg_style_loss:.2f}\t"
                    f"TVA:{agg_tva_loss:.2f}\tTOT:{agg_total_loss:.2f}\t"
#                     f"S/CT:{style_image_count:3}/{count:3}"
                   )
            print(mesg)

RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #3 'weight'

In [25]:
%debug

> [0;32m/home/ubuntu/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/autograd/__init__.py[0m(90)[0;36mbackward[0;34m()[0m
[0;32m     88 [0;31m    Variable._execution_engine.run_backward(
[0m[0;32m     89 [0;31m        [0mtensors[0m[0;34m,[0m [0mgrad_tensors[0m[0;34m,[0m [0mretain_graph[0m[0;34m,[0m [0mcreate_graph[0m[0;34m,[0m[0;34m[0m[0m
[0m[0;32m---> 90 [0;31m        allow_unreachable=True)  # allow_unreachable flag
[0m[0;32m     91 [0;31m[0;34m[0m[0m
[0m[0;32m     92 [0;31m[0;34m[0m[0m
[0m
ipdb> up
> [0;32m/home/ubuntu/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/tensor.py[0m(96)[0;36mbackward[0;34m()[0m
[0;32m     94 [0;31m                [0mproducts[0m[0;34m.[0m [0mDefaults[0m [0mto[0m[0;31m [0m[0;31m`[0m[0;31m`[0m[0;32mFalse[0m[0;31m`[0m[0;31m`[0m[0;34m.[0m[0;34m[0m[0m
[0m[0;32m     95 [0;31m        """
[0m[0;32m---> 96 [0;31m        [0mtorch[0m[0;34m.[0m[0mautograd[0m[0;3

In [18]:

start = time.time()
m_com.train()
style_image_count = 0
for e in range(epochs):
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_tva_loss = 0.
    count = 0
    batch_tot = len(train_dl)
    for batch_id, (x_con,_) in enumerate(train_dl):
        scheduler.update_lr(e, batch_id, batch_tot)
        
#         if (batch_id % (log_interval*3) == 0):
        try:
            x_style,_ = next(it_style)
            style_image_count += 1
        except:
            it_style = iter(style_dl)
            x_style,_ = next(it_style)
            print('Restarting style')
            style_image_count = 1
            
        with torch.no_grad(): 
            style_batch = x_style.repeat(bs,1,1,1)
            s_out = m_vgg(style_batch)
            style_feat = [s.clone() for s in s_out]
            
            targ_feat = m_vgg(x_con)[c_block].clone()
            
        n_batch = x_con.size(0)
        count += n_batch
        optimizer.zero_grad()
        
        out = m_com(x_con, x_style)
        out,_ = data_norm((out,None))
        inp_feat = m_vgg(out)
        
        closs = [ct_loss(inp_feat[c_block],targ_feat) * ct_wgt]
        sloss = [gram_loss(inp,targ)*wgt for inp,targ,wgt in zip(inp_feat, style_feat, style_wgts) if wgt > 0]
        tvaloss = tva_loss(out) * tva_wgt
        
        total_loss = closs + sloss + [tvaloss]
        total_loss = sum(total_loss)
    
        if fp16:
            optimizer.backward(total_loss)
        else:
            total_loss.backward()
            
#         nn.utils.clip_grad_norm_(m_com.m_tran.parameters(), 10)
#         nn.utils.clip_grad_norm_(m_com.m_style.parameters(), 80)
        optimizer.step()
    
        mom = 0.9
        agg_content_loss = agg_content_loss*mom + sum(closs).detach().data*(1-mom)
        agg_style_loss = agg_style_loss*mom + sum(sloss).detach().data*(1-mom)
        agg_tva_loss = agg_tva_loss*mom + tvaloss.detach().data*(1-mom)
        agg_total_loss = (agg_content_loss + agg_style_loss + agg_tva_loss)
        
        if (batch_id + 1) % log_interval == 0:
            time_elapsed = (time.time() - start)/60
            b_tot = len(train_dl)
            mesg = (f"MIN:{time_elapsed:.2f}\tEP[{e+1}]\tB[{batch_id+1:4}/{b_tot}]\t"
                    f"CON:{agg_content_loss:.3f}\tSTYL:{agg_style_loss:.2f}\t"
                    f"TVA:{agg_tva_loss:.2f}\tTOT:{agg_total_loss:.2f}\t"
#                     f"S/CT:{style_image_count:3}/{count:3}"
                   )
            print(mesg)

Restarting style
Changing LR from 1e-05 to 1.001223867922172e-05
MIN:0.18	EP[1]	B[  50/40037]	CON:21.521	STYL:360.43	TVA:0.27	TOT:382.22	
MIN:0.30	EP[1]	B[ 100/40037]	CON:21.607	STYL:129.86	TVA:0.40	TOT:151.86	
MIN:0.42	EP[1]	B[ 150/40037]	CON:21.065	STYL:166.64	TVA:0.53	TOT:188.23	
MIN:0.53	EP[1]	B[ 200/40037]	CON:21.459	STYL:295.78	TVA:0.67	TOT:317.90	
MIN:0.65	EP[1]	B[ 250/40037]	CON:21.166	STYL:165.36	TVA:0.80	TOT:187.33	
MIN:0.76	EP[1]	B[ 300/40037]	CON:21.221	STYL:161.49	TVA:0.91	TOT:183.63	
MIN:0.88	EP[1]	B[ 350/40037]	CON:21.484	STYL:84.43	TVA:0.99	TOT:106.90	
MIN:0.99	EP[1]	B[ 400/40037]	CON:21.730	STYL:283.69	TVA:1.05	TOT:306.47	
MIN:1.11	EP[1]	B[ 450/40037]	CON:21.404	STYL:200.80	TVA:1.25	TOT:223.45	
Restarting style
MIN:1.23	EP[1]	B[ 500/40037]	CON:21.450	STYL:206.04	TVA:1.37	TOT:228.87	
MIN:1.35	EP[1]	B[ 550/40037]	CON:21.330	STYL:179.09	TVA:1.56	TOT:201.97	
MIN:1.46	EP[1]	B[ 600/40037]	CON:21.096	STYL:120.15	TVA:1.59	TOT:142.83	
MIN:1.58	EP[1]	B[ 650/40037]	CON:21.560	STY

KeyboardInterrupt: 

In [18]:

start = time.time()
m_com.train()
style_image_count = 0
for e in range(epochs):
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_tva_loss = 0.
    count = 0
    batch_tot = len(train_dl)
    for batch_id, (x_con,_) in enumerate(train_dl):
        scheduler.update_lr(e, batch_id, batch_tot)
        
#         if (batch_id % (log_interval*3) == 0):
        try:
            x_style,_ = next(it_style)
            style_image_count += 1
        except:
            it_style = iter(style_dl)
            x_style,_ = next(it_style)
            print('Restarting style')
            style_image_count = 1
            
        with torch.no_grad(): 
            style_batch = x_style.repeat(bs,1,1,1)
            s_out = m_vgg(style_batch)
            style_feat = [s.clone() for s in s_out]
            
            targ_feat = m_vgg(x_con)[c_block].clone()
            
        n_batch = x_con.size(0)
        count += n_batch
        optimizer.zero_grad()
        
        out = m_com(x_con, x_style)
        out,_ = data_norm((out,None))
        inp_feat = m_vgg(out)
        
        closs = [ct_loss(inp_feat[c_block],targ_feat) * ct_wgt]
        sloss = [gram_loss(inp,targ)*wgt for inp,targ,wgt in zip(inp_feat, style_feat, style_wgts) if wgt > 0]
        tvaloss = tva_loss(out) * tva_wgt
        
        total_loss = closs + sloss + [tvaloss]
        total_loss = sum(total_loss)
    
        if fp16:
            optimizer.backward(total_loss)
        else:
            total_loss.backward()
            
#         nn.utils.clip_grad_norm_(m_com.m_tran.parameters(), 10)
#         nn.utils.clip_grad_norm_(m_com.m_style.parameters(), 80)
        optimizer.step()
    
        mom = 0.9
        agg_content_loss = agg_content_loss*mom + sum(closs).detach().data*(1-mom)
        agg_style_loss = agg_style_loss*mom + sum(sloss).detach().data*(1-mom)
        agg_tva_loss = agg_tva_loss*mom + tvaloss.detach().data*(1-mom)
        agg_total_loss = (agg_content_loss + agg_style_loss + agg_tva_loss)
        
        if (batch_id + 1) % log_interval == 0:
            time_elapsed = (time.time() - start)/60
            b_tot = len(train_dl)
            mesg = (f"MIN:{time_elapsed:.2f}\tEP[{e+1}]\tB[{batch_id+1:4}/{b_tot}]\t"
                    f"CON:{agg_content_loss:.3f}\tSTYL:{agg_style_loss:.2f}\t"
                    f"TVA:{agg_tva_loss:.2f}\tTOT:{agg_total_loss:.2f}\t"
#                     f"S/CT:{style_image_count:3}/{count:3}"
                   )
            print(mesg)

Restarting style
Changing LR from 1e-05 to 1.001223867922172e-05
MIN:0.23	EP[1]	B[  50/40037]	CON:21.390	STYL:208.79	TVA:0.55	TOT:230.73	
MIN:0.35	EP[1]	B[ 100/40037]	CON:21.632	STYL:221.01	TVA:0.75	TOT:243.39	
MIN:0.48	EP[1]	B[ 150/40037]	CON:21.328	STYL:141.67	TVA:0.93	TOT:163.93	
MIN:0.61	EP[1]	B[ 200/40037]	CON:21.494	STYL:113.39	TVA:1.16	TOT:136.04	
MIN:0.73	EP[1]	B[ 250/40037]	CON:21.622	STYL:87.02	TVA:1.36	TOT:110.01	
MIN:0.86	EP[1]	B[ 300/40037]	CON:21.507	STYL:151.13	TVA:1.49	TOT:174.12	
MIN:0.99	EP[1]	B[ 350/40037]	CON:21.483	STYL:135.05	TVA:1.68	TOT:158.21	
MIN:1.12	EP[1]	B[ 400/40037]	CON:21.700	STYL:134.94	TVA:1.77	TOT:158.40	
MIN:1.24	EP[1]	B[ 450/40037]	CON:21.594	STYL:81.95	TVA:1.91	TOT:105.45	
Restarting style
MIN:1.38	EP[1]	B[ 500/40037]	CON:21.775	STYL:121.79	TVA:1.98	TOT:145.54	
MIN:1.50	EP[1]	B[ 550/40037]	CON:21.886	STYL:155.54	TVA:1.94	TOT:179.36	
MIN:1.63	EP[1]	B[ 600/40037]	CON:22.104	STYL:398.23	TVA:2.19	TOT:422.52	


KeyboardInterrupt: 

In [None]:
def eval_imgs(x_con, x_style, idx=0):
    with torch.no_grad(): 
        out = m_com(x_con, x_style)
    fig, axs = plt.subplots(1,3,figsize=(12,4))
    Image(data_denorm(x_con[idx].cpu())).show(axs[0])
    axs[0].set_title('Content')
    axs[1].set_title('Style')
    axs[2].set_title('Transfer')
    Image(data_denorm(x_style[0].cpu())).show(axs[1])
    Image((out[idx].detach().cpu())).show(axs[2])
    

In [None]:
x_con,_ = next(iter(train_dl))

In [None]:
x_style,_ = next(iter(style_dl))
eval_imgs(x_con, x_style)

In [None]:
x_style,_ = next(iter(style_dl))
eval_imgs(x_con, x_style)