In [1]:
# Notebook setup: draw matplotlib figures inline and enable autoreload so that
# edits to the local modules imported in the next cell (adamw, scheduler,
# models) are picked up without restarting the kernel.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai import *
from fastai.vision import *
torch.backends.cudnn.benchmark=True
import time

from adamw import AdamW
from scheduler import Scheduler
from models import *

In [3]:
# Data root; model weights live in MODEL_PATH (created if missing, parent must
# already exist). `save_tag` names checkpoints, e.g. the commented-out
# `model_combined_{save_tag}.pth` load cell further down.
PATH = Path('../fastai_docs/data')
MODEL_PATH = PATH.joinpath('models')
MODEL_PATH.mkdir(exist_ok=True)
save_tag = 'imagenet_magenta'

In [4]:

IMAGENET_PATH = Path('../data/imagenet-sz/160/train')
train_ds = ImageClassificationDataset.from_folder(IMAGENET_PATH)

# Crop size / batch size pairs; uncomment one to switch.
# size,bs = 96,36
size,bs = 128,32
# size,bs = 256,20

data_norm,data_denorm = normalize_funcs(*imagenet_stats)

def _crop_tds(ds):
    "Wrap `ds` so each item is cropped to `size`x`size` with a non-random crop (is_random=False)."
    return DatasetTfm(ds, tfms=[crop_pad(size=size, is_random=False)], tfm_y=False, size=size, do_crop=True)

def _train_only_bunch(tds, batch_size):
    "Normalised DataBunch over `tds` with the validation loader removed."
    bunch = DataBunch.create(tds, valid_ds=None, bs=batch_size, tfms=data_norm)
    bunch.valid_dl = None
    return bunch

# Content data: ImageNet training images, `bs` per batch.
train_tds = _crop_tds(train_ds)
data = _train_only_bunch(train_tds, bs)

# Style data: a single style image per batch.
STYLE_PATH = PATH/'style/dtd/images'
# STYLE_PATH = PATH/'style/dtd/subset'
style_ds = ImageClassificationDataset.from_folder(STYLE_PATH)

# STYLE_PATH = PATH/'style/pbn/train'
# style_ds = ImageClassificationDataset.from_single_folder(STYLE_PATH, ['train'])

style_tds = _crop_tds(style_ds)
style_data = _train_only_bunch(style_tds, 1)

In [5]:
# losses
def ct_loss(input, target):
    """Content loss: mean squared error between two feature tensors."""
    return F.mse_loss(input, target, reduction='mean')

def gram(input):
    """Batched Gram matrix of feature maps, normalised by the element count.

    Args:
        input: tensor of shape (b, c, *spatial), e.g. (b, c, h, w) activations.

    Returns:
        Tensor of shape (b, c, c): x @ x^T per batch item, divided by
        c * prod(spatial) (i.e. c*h*w for 4-D input).
    """
    b, c = input.shape[:2]
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. permuted / channels_last activations); reshape copies only if needed.
    x = input.reshape(b, c, -1)
    return torch.bmm(x, x.transpose(1, 2)) / (c * x.size(2))

def gram_loss(input, target):
    """Style loss: MSE between the Gram matrices of `input` and `target`."""
    g_input = gram(input)
    g_target = gram(target)
    return F.mse_loss(g_input, g_target)

def tva_loss(y):
    """Total-variation penalty for a 4-D batch `y`.

    Sums the absolute differences between neighbours along the last axis
    (width) and along the second-to-last axis (height).
    """
    horiz = (y[:, :, :, 1:] - y[:, :, :, :-1]).abs().sum()
    vert = (y[:, :, 1:, :] - y[:, :, :-1, :]).abs().sum()
    return horiz + vert

## Create models

In [6]:
# Build the two sub-networks (transformer and, presumably, an inception-based
# style predictor) and wrap them in one module moved to the GPU.
mt, ms = StyleTransformer(), StylePredict.create_inception()
m_com = CombinedModel(mt, ms).cuda()

In [7]:
# VGG activation extractor used for the content/style losses; on construction
# it appears to print the layer ids it taps (see output below).
m_vgg = VGGActivations()
m_vgg.cuda()

Layer ids:  [12, 22, 32, 42]


### Set params

In [8]:
# Training length and how often (in batches) to print running losses.
epochs, log_interval = 3, 50

# AdamW from the local `adamw` module. The lr given here is only a placeholder:
# the Scheduler built in the next cell appears to reset it (its first log line
# reports "Changing LR from 1e-05").
optimizer = AdamW(
    m_com.parameters(),
    lr=5e-3,
    betas=(0.9, 0.999),
    weight_decay=1e-5,
)

In [9]:
# LR phases keyed by epoch range: warm up over epoch 0, come back down over
# epoch 1, then decay towards ~0 for the remaining epochs. The logged per-batch
# step is consistent with linear interpolation within each phase.
lr_phases = [
    {'ep': (0, 1),      'lr': (1e-5, 5e-4)},
    {'ep': (1, 2),      'lr': (5e-4, 1e-5)},
    {'ep': (2, epochs), 'lr': (1e-5, 1e-7)},
]
scheduler = Scheduler(optimizer, lr_phases)

In [10]:
# m_com.load_state_dict(torch.load(MODEL_PATH/f'model_combined_4_imagenet_256.pth'), strict=False)
# m_com.load_state_dict(torch.load(MODEL_PATH/f'model_combined_{save_tag}.pth'), strict=True)

In [11]:
# Loss weights. style_wgts[i] scales the Gram loss of VGG activation i
# (VGG blocks 2,3,4,5); c_block picks the activation used for the content
# loss (index 1 -> block 3).
style_wgts = [w * 1e9 for w in (10, 50, 10, 1)]
c_block = 1
ct_wgt = 5e2
tva_wgt = 5e-6

In [None]:

# Training loop. For each content batch: stylise it with the current style
# image, score the output against VGG content/style targets, and step the
# optimizer. A new style image is drawn every `log_interval*3` batches
# (including batch 0 of every epoch, since batch_id restarts each epoch).
start = time.time()
m_com.train()
style_image_count = 0
# Create the style iterator up front instead of relying on a NameError from
# the first `next(it_style)`; it is restarted below whenever it runs out.
it_style = iter(style_data.train_dl)
mom = 0.9  # smoothing factor for the logged exponential moving averages
for e in range(epochs):
    agg_content_loss = 0.
    agg_style_loss = 0.
    agg_tva_loss = 0.
    count = 0
    batch_tot = len(data.train_dl)
    for batch_id, (x_con,_) in enumerate(data.train_dl):
        scheduler.update_lr(e, batch_id, batch_tot)

        if batch_id % (log_interval*3) == 0:
            try:
                x_style,_ = next(it_style)
            except StopIteration:
                # Only exhaustion restarts the loader; other errors (incl.
                # KeyboardInterrupt / CUDA errors) now propagate.
                it_style = iter(style_data.train_dl)
                x_style,_ = next(it_style)
            style_image_count += 1
            # Style targets depend only on x_style, so compute them once per
            # style image instead of on every batch.
            with torch.no_grad():
                style_batch = x_style.repeat(bs,1,1,1)
                style_feat_full = [s.clone() for s in m_vgg(style_batch)]

        n_batch = x_con.size(0)
        count += n_batch

        with torch.no_grad():
            targ_feat = m_vgg(x_con)[c_block].clone()
        # Trim style targets to this batch's size so a short final batch
        # does not produce mismatched Gram-matrix shapes.
        style_feat = [s[:n_batch] for s in style_feat_full]

        optimizer.zero_grad()

        out = m_com(x_con, x_style)
        # Normalise the output the same way the content batch is normalised
        # (data_norm) before feeding it to VGG.
        out,_ = data_norm((out,None))
        inp_feat = m_vgg(out)

        closs = [ct_loss(inp_feat[c_block],targ_feat) * ct_wgt]
        sloss = [gram_loss(inp,targ)*wgt for inp,targ,wgt in zip(inp_feat, style_feat, style_wgts) if wgt > 0]
        tvaloss = tva_loss(out) * tva_wgt

        total_loss = sum(closs + sloss + [tvaloss])
        total_loss.backward()
        # nn.utils.clip_grad_norm_(m_com.m_tran.parameters(), 10)
        # nn.utils.clip_grad_norm_(m_com.m_style.parameters(), 80)
        optimizer.step()

        # Exponential moving averages of each loss term (logging only).
        agg_content_loss = agg_content_loss*mom + sum(closs).detach()*(1-mom)
        agg_style_loss = agg_style_loss*mom + sum(sloss).detach()*(1-mom)
        agg_tva_loss = agg_tva_loss*mom + tvaloss.detach()*(1-mom)
        agg_total_loss = (agg_content_loss + agg_style_loss + agg_tva_loss)

        if (batch_id + 1) % log_interval == 0:
            time_elapsed = (time.time() - start)/60
            mesg = "Min:{:.3f}\tEp{}: [{}/{}]\tcontent: {:.3f}\tstyle: {:.3f}\ttva: {:.3f}\ttotal: {:.3f}\t simg: {}".format(
                time_elapsed, e + 1, count, batch_tot*bs,
                agg_content_loss,
                agg_style_loss,
                agg_tva_loss,
                agg_total_loss,
                style_image_count
            )
            print(mesg)
        
            

Changing LR from 1e-05 to 1.001223867922172e-05
Min:0.189	Ep1: [1600/1281184]	content: 2.180	style: 210.209	tva: 0.633	total: 213.023	 simg: 1
Min:0.303	Ep1: [3200/1281184]	content: 2.248	style: 175.799	tva: 0.842	total: 178.890	 simg: 1
Min:0.417	Ep1: [4800/1281184]	content: 2.305	style: 155.502	tva: 1.022	total: 158.828	 simg: 1
Min:0.532	Ep1: [6400/1281184]	content: 2.233	style: 109.204	tva: 0.944	total: 112.382	 simg: 2
Min:0.646	Ep1: [8000/1281184]	content: 2.328	style: 94.789	tva: 1.064	total: 98.180	 simg: 2
Min:0.760	Ep1: [9600/1281184]	content: 2.430	style: 84.786	tva: 1.190	total: 88.405	 simg: 2
Min:0.875	Ep1: [11200/1281184]	content: 2.370	style: 46.975	tva: 1.260	total: 50.605	 simg: 3
Min:0.989	Ep1: [12800/1281184]	content: 2.384	style: 42.146	tva: 1.269	total: 45.799	 simg: 3
Min:1.103	Ep1: [14400/1281184]	content: 2.423	style: 38.501	tva: 1.294	total: 42.218	 simg: 3
Min:1.218	Ep1: [16000/1281184]	content: 2.221	style: 80.882	tva: 2.114	total: 85.217	 simg: 4
Min:1.332	

In [None]:
def eval_imgs(x_con, x_style, idx=0):
    """Display the stylised output and the original content image for one batch item.

    Args:
        x_con: normalised content batch (as yielded by `data.train_dl`).
        x_style: style batch passed to `m_com` together with `x_con`.
        idx: index of the batch item to display.
    """
    with torch.no_grad():
        out = m_com(x_con, x_style)
    # Shown as-is: the training loop only normalises m_com's output (data_norm)
    # before VGG, so presumably it is already in display space — TODO confirm range.
    Image(out[idx].cpu()).show()
    # The content batch is normalised by the DataBunch, so undo that for display.
    Image(data_denorm(x_con[idx].cpu())).show()
    